Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions gotests.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"sort"
"sync"

"github.com/cweill/gotests/internal/gomod"
"github.com/cweill/gotests/internal/goparser"
"github.com/cweill/gotests/internal/input"
"github.com/cweill/gotests/internal/models"
Expand All @@ -21,6 +22,7 @@ type Options struct {
Only *regexp.Regexp // Includes only functions that match.
Exclude *regexp.Regexp // Excludes functions that match.
Exported bool // Include only exported methods
PackageTest bool // Adds _test package suffix for tests
PrintInputs bool // Print function parameters in error messages
Subtests bool // Print tests using Go 1.7 subtests
Parallel bool // Print tests that runs the subtests in parallel.
Expand Down Expand Up @@ -121,6 +123,16 @@ func generateTest(src models.Path, files []models.Path, opt *Options) (*Generate
if err != nil {
return nil, err
}
if opt.PackageTest && opt.Exported {
fullImportPath, fullImportPathErr := gomod.GetFullImportPath(string(src))
if fullImportPathErr != nil {
return nil, fullImportPathErr
}
h.Imports = append(h.Imports, &models.Import{
Name: h.Package,
Path: fmt.Sprintf("%q", fullImportPath),
})
}
funcs := testableFuncs(sr.Funcs, opt.Only, opt.Exclude, opt.Exported, tf)
if len(funcs) == 0 {
return nil, nil
Expand All @@ -129,6 +141,7 @@ func generateTest(src models.Path, files []models.Path, opt *Options) (*Generate
options := output.Options{
PrintInputs: opt.PrintInputs,
Subtests: opt.Subtests,
PackageTest: opt.PackageTest,
Parallel: opt.Parallel,
Named: opt.Named,
UseGoCmp: opt.UseGoCmp,
Expand All @@ -142,6 +155,9 @@ func generateTest(src models.Path, files []models.Path, opt *Options) (*Generate
AIMinCases: opt.AIMinCases,
AIMaxCases: opt.AIMaxCases,
}
if opt.PackageTest {
h.PackageTest = true
}

b, err := options.Process(h, funcs)
if err != nil {
Expand Down
13 changes: 12 additions & 1 deletion gotests_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ func TestGenerateTests(t *testing.T) {
only *regexp.Regexp
excl *regexp.Regexp
exported bool
packageTest bool
printInputs bool
subtests bool
parallel bool
Expand Down Expand Up @@ -438,6 +439,15 @@ func TestGenerateTests(t *testing.T) {
},
want: mustReadAndFormatGoFile(t, "testdata/goldens/multiple_functions_filtering_exported.go"),
},
{
name: "Multiple functions filtering exported and _test package",
args: args{
srcPath: `testdata/test_filter.go`,
exported: true,
packageTest: true,
},
want: mustReadAndFormatGoFile(t, "testdata/goldens/multiple_functions_filtering_exported_with_packagetest.go"),
},
{
name: "Multiple functions filtering exported with only",
args: args{
Expand Down Expand Up @@ -891,6 +901,7 @@ func TestGenerateTests(t *testing.T) {
Only: tt.args.only,
Exclude: tt.args.excl,
Exported: tt.args.exported,
PackageTest: tt.args.packageTest,
PrintInputs: tt.args.printInputs,
Subtests: tt.args.subtests,
Parallel: tt.args.parallel,
Expand Down Expand Up @@ -948,7 +959,7 @@ func mustReadAndFormatGoFile(t *testing.T, filename string) string {

func outputResult(t *testing.T, tmpDir, testName string, got []byte) {
tmpResult := path.Join(tmpDir, toSnakeCase(testName)+".go")
if err := ioutil.WriteFile(tmpResult, got, 0644); err != nil {
if err := ioutil.WriteFile(tmpResult, got, 0o644); err != nil {
t.Errorf("ioutil.WriteFile: %v", err)
}
t.Logf("%s", tmpResult)
Expand Down
2 changes: 1 addition & 1 deletion internal/ai/e2e_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,7 @@ func TestE2E_OllamaGeneration_ValidatesStructure(t *testing.T) {

// Render test function with same parameters as CLI uses
// (printInputs=false, subtests=true, named=false, parallel=false, useGoCmp=false)
if err := r.TestFunction(&buf, targetFunc, false, true, false, false, false, nil, aiCases); err != nil {
if err := r.TestFunction(&buf, targetFunc, false, true, false, false, false, nil, aiCases, false, ""); err != nil {
t.Fatalf("Failed to render test function: %v", err)
}

Expand Down
41 changes: 41 additions & 0 deletions internal/gomod/example_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
package gomod_test

import (
"fmt"
"log"

"github.com/cweill/gotests/internal/gomod"
)

func ExampleGetFullImportPath_file() {
// Get import path for a specific Go file
importPath, err := gomod.GetFullImportPath("gomod.go")
if err != nil {
log.Fatal(err)
}

fmt.Println(importPath)
// Output: github.com/cweill/gotests/internal/gomod
}

func ExampleGetFullImportPath_directory() {
// Get import path for a package directory
importPath, err := gomod.GetFullImportPath(".")
if err != nil {
log.Fatal(err)
}

fmt.Println(importPath)
// Output: github.com/cweill/gotests/internal/gomod
}

func ExampleGetFullImportPath_moduleRoot() {
// Get import path for the module root directory
importPath, err := gomod.GetFullImportPath("../..")
if err != nil {
log.Fatal(err)
}

fmt.Println(importPath)
// Output: github.com/cweill/gotests
}
106 changes: 106 additions & 0 deletions internal/gomod/gomod.go
Copy link
Author

@IlyasYOY IlyasYOY Oct 27, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I must mention that the gomod was generated by AI and reviewed by myself.

Everything else was done manually.

Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
package gomod

import (
"bufio"
"fmt"
"os"
"path/filepath"
"strings"
)

// GetFullImportPath resolves the full Go import path for any file or directory
// within a Go module. Returns the complete import path like "github.com/user/repo/pkg".
//
// startAt can be either:
// - A Go source file path (e.g., "/path/to/project/main.go")
// - A directory path (e.g., "/path/to/project/pkg")
// - An absolute or relative path
//
// Returns an error if:
// - No go.mod found in the directory tree
// - go.mod is malformed or missing module directive
// - Path resolution fails
func GetFullImportPath(startAt string) (string, error) {
absPath, err := filepath.Abs(startAt)
if err != nil {
return "", fmt.Errorf("failed to get absolute path for %s: %w", startAt, err)
}

// If it's a file, get its directory
if info, err := os.Stat(absPath); err == nil && !info.IsDir() {
absPath = filepath.Dir(absPath)
}

modRoot, err := findGoMod(absPath)
if err != nil {
return "", err
}

modulePath, err := parseModulePath(modRoot)
if err != nil {
return "", err
}

relPath, err := filepath.Rel(modRoot, absPath)
if err != nil {
return "", fmt.Errorf("failed to calculate relative path from %s to %s: %w", modRoot, absPath, err)
}

if relPath == "." {
return modulePath, nil
}

return filepath.Join(modulePath, relPath), nil
}

// findGoMod walks up the directory tree from startDir to find a go.mod file.
// Returns the directory containing go.mod, or an error if not found.
func findGoMod(startDir string) (string, error) {
current := startDir

for {
modPath := filepath.Join(current, "go.mod")
if _, err := os.Stat(modPath); err == nil {
return current, nil
}

parent := filepath.Dir(current)
if parent == current {
// Reached root directory
break
}
current = parent
}

return "", fmt.Errorf("go.mod not found in %s or any parent directory", startDir)
}

// parseModulePath reads the go.mod file in modRoot and extracts the module path.
// Returns the module path or an error if parsing fails.
func parseModulePath(modRoot string) (string, error) {
modFile := filepath.Join(modRoot, "go.mod")

file, err := os.Open(modFile)
if err != nil {
return "", fmt.Errorf("failed to open go.mod at %s: %w", modFile, err)
}
defer file.Close()

scanner := bufio.NewScanner(file)
for scanner.Scan() {
line := strings.TrimSpace(scanner.Text())
if strings.HasPrefix(line, "module ") {
modulePath := strings.TrimSpace(line[7:]) // Remove "module " prefix
if modulePath == "" {
return "", fmt.Errorf("empty module path in %s", modFile)
}
return modulePath, nil
}
}

if err := scanner.Err(); err != nil {
return "", fmt.Errorf("error reading go.mod at %s: %w", modFile, err)
}

return "", fmt.Errorf("module directive not found in %s", modFile)
}
Loading