Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 34 additions & 8 deletions cmd/generate.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,28 @@ import (
"github.com/printeers/trek/internal/postgres"
)

// ExitError is an error that carries a specific exit code.
type ExitError struct {
Code int
Message string
}

func (e *ExitError) Error() string {
return e.Message
}

// ErrDiffDetected is returned by generate when diff statements are generated.
var ErrDiffDetected = &ExitError{Code: 2, Message: "diff statements detected"}

//nolint:gocognit,cyclop
func NewGenerateCommand() *cobra.Command {
var (
dev bool
cleanup bool
overwrite bool
stdout bool
check bool
dev bool
cleanup bool
overwrite bool
stdout bool
check bool
errorOnDiff bool
)

generateCmd := &cobra.Command{
Expand Down Expand Up @@ -99,7 +113,7 @@ func NewGenerateCommand() *cobra.Command {
}
}

err = runWithStdout(ctx, config, wd, tmpDir, migrationsDir, len(migrationFiles) == 0)
err = runWithStdout(ctx, config, wd, tmpDir, migrationsDir, len(migrationFiles) == 0, errorOnDiff)
if err != nil {
return err
}
Expand All @@ -115,7 +129,7 @@ func NewGenerateCommand() *cobra.Command {
return fmt.Errorf("failed to create temporary directory: %w", err)
}

err = runWithStdout(ctx, config, wd, tmpDir, migrationsDir, len(migrationFiles) == 0)
err = runWithStdout(ctx, config, wd, tmpDir, migrationsDir, len(migrationFiles) == 0, errorOnDiff)
if err != nil {
return err
}
Expand Down Expand Up @@ -156,7 +170,8 @@ func NewGenerateCommand() *cobra.Command {
}

var updated bool
updated, err = runWithFile(ctx, config, wd, tmpDir, migrationsDir, newMigrationFilePath, migrationNumber)
updated, err = runWithFile(
ctx, config, wd, tmpDir, migrationsDir, newMigrationFilePath, migrationNumber, errorOnDiff)
if err != nil {
return err
}
Expand Down Expand Up @@ -200,6 +215,7 @@ func NewGenerateCommand() *cobra.Command {
generateCmd.Flags().BoolVar(&overwrite, "overwrite", false, "Overwrite existing files")
generateCmd.Flags().BoolVar(&stdout, "stdout", false, "Output migration statements to stdout")
generateCmd.Flags().BoolVar(&check, "check", true, "Run checks after generating the migration")
generateCmd.Flags().BoolVar(&errorOnDiff, "error-on-diff", false, "Exit with code 2 if diff statements are generated")

return generateCmd
}
Expand All @@ -222,6 +238,7 @@ func runWithStdout(
tmpDir,
migrationsDir string,
initial bool,
errorOnDiff bool,
) error {
updated, err := checkIfUpdated(config, wd)
if err != nil {
Expand Down Expand Up @@ -313,6 +330,10 @@ func runWithStdout(
fmt.Println("--")
fmt.Println(statements)
fmt.Println("--")

if errorOnDiff && statements != "" {
return ErrDiffDetected
}
}

return nil
Expand All @@ -327,6 +348,7 @@ func runWithFile(
migrationsDir,
newMigrationFilePath string,
migrationNumber uint,
errorOnDiff bool,
) (bool, error) {
updated, err := checkIfUpdated(config, wd)
if err != nil {
Expand Down Expand Up @@ -412,6 +434,10 @@ func runWithFile(
return false, fmt.Errorf("failed to write template files: %w", err)
}

if errorOnDiff && statements != "" {
return true, ErrDiffDetected
}

return true, nil
}

Expand Down
11 changes: 10 additions & 1 deletion cmd/init.go
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,16 @@ func NewInitCommand() *cobra.Command {
return fmt.Errorf("failed to create temporary directory: %w", err)
}

_, err = runWithFile(ctx, config, wd, tmpDir, migrationsDir, filepath.Join(migrationsDir, "001_init.up.sql"), 1)
_, err = runWithFile(
ctx,
config,
wd,
tmpDir,
migrationsDir,
filepath.Join(migrationsDir, "001_init.up.sql"),
1,
false,
)
if err != nil {
return fmt.Errorf("failed to generate first migration: %w", err)
}
Expand Down
5 changes: 5 additions & 0 deletions main.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package main

import (
"errors"
"fmt"
"os"

Expand All @@ -10,6 +11,10 @@ import (
func main() {
if err := cmd.NewRootCommand().Execute(); err != nil {
_, _ = fmt.Fprintln(os.Stderr, err)
var exitErr *cmd.ExitError
if errors.As(err, &exitErr) {
os.Exit(exitErr.Code)
}
os.Exit(1)
}
}
Loading