Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,59 +1,25 @@
package dialectquery

import "strings"
package dialect

// Querier is the interface that wraps the basic methods to create a dialect specific query.
//
// It is intended tio be using with [database.NewStoreFromQuerier] to create a new [database.Store]
// implementation based on a custom querier.
type Querier interface {
// CreateTable returns the SQL query string to create the db version table.
CreateTable(tableName string) string

// InsertVersion returns the SQL query string to insert a new version into the db version table.
InsertVersion(tableName string) string

// DeleteVersion returns the SQL query string to delete a version from the db version table.
DeleteVersion(tableName string) string

// GetMigrationByVersion returns the SQL query string to get a single migration by version.
//
// The query should return the timestamp and is_applied columns.
GetMigrationByVersion(tableName string) string

// ListMigrations returns the SQL query string to list all migrations in descending order by id.
//
// The query should return the version_id and is_applied columns.
ListMigrations(tableName string) string

// GetLatestVersion returns the SQL query string to get the last version_id from the db version
// table. Returns a nullable int64 value.
GetLatestVersion(tableName string) string
}

var _ Querier = (*QueryController)(nil)

type QueryController struct{ Querier }

// NewQueryController returns a new QueryController that wraps the given Querier.
func NewQueryController(querier Querier) *QueryController {
return &QueryController{Querier: querier}
}

// Optional methods

// TableExists returns the SQL query string to check if the version table exists. If the Querier
// does not implement this method, it will return an empty string.
//
// Returns a boolean value.
func (c *QueryController) TableExists(tableName string) string {
if t, ok := c.Querier.(interface{ TableExists(string) string }); ok {
return t.TableExists(tableName)
}
return ""
}

func parseTableIdentifier(name string) (schema, table string) {
schema, table, found := strings.Cut(name, ".")
if !found {
return "", name
}
return schema, table
}
22 changes: 22 additions & 0 deletions database/dialect/querier_extended.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
package dialect

// QuerierExtender extends the [Querier] interface with optional database-specific optimizations.
// While not required, implementing these methods can improve performance.
//
// IMPORTANT: This interface may be expanded in future versions. Implementors must be prepared to
// update their implementations when new methods are added.
//
// Example compile-time check:
//
// var _ QuerierExtender = (*CustomQuerierExtended)(nil)
//
// In short, it's exported to allow implementors to have a compile-time check that they are
// implementing the interface correctly.
type QuerierExtender interface {
Querier

// TableExists returns a database-specific SQL query to check if a table exists. For example,
// implementations might query system catalogs like pg_tables or sqlite_master. Return empty
// string if not supported.
TableExists(tableName string) string
}
108 changes: 72 additions & 36 deletions database/dialect.go → database/dialects.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,86 +6,100 @@ import (
"errors"
"fmt"

"github.com/pressly/goose/v3/internal/dialect/dialectquery"
"github.com/pressly/goose/v3/database/dialect"
"github.com/pressly/goose/v3/internal/dialects"
)

// Dialect is the type of database dialect.
type Dialect string

const (
DialectCustom Dialect = ""
DialectClickHouse Dialect = "clickhouse"
DialectMSSQL Dialect = "mssql"
DialectMySQL Dialect = "mysql"
DialectPostgres Dialect = "postgres"
DialectRedshift Dialect = "redshift"
DialectSQLite3 Dialect = "sqlite3"
DialectStarrocks Dialect = "starrocks"
DialectTiDB Dialect = "tidb"
DialectTurso Dialect = "turso"
DialectVertica Dialect = "vertica"
DialectYdB Dialect = "ydb"
DialectStarrocks Dialect = "starrocks"
)

// NewStore returns a new [Store] implementation for the given dialect.
func NewStore(dialect Dialect, tablename string) (Store, error) {
if tablename == "" {
return nil, errors.New("table name must not be empty")
func NewStore(d Dialect, tableName string) (Store, error) {
if d == DialectCustom {
return nil, errors.New("custom dialect is not supported")
}
if dialect == "" {
return nil, errors.New("dialect must not be empty")
}
lookup := map[Dialect]dialectquery.Querier{
DialectClickHouse: &dialectquery.Clickhouse{},
DialectMSSQL: &dialectquery.Sqlserver{},
DialectMySQL: &dialectquery.Mysql{},
DialectPostgres: &dialectquery.Postgres{},
DialectRedshift: &dialectquery.Redshift{},
DialectSQLite3: &dialectquery.Sqlite3{},
DialectTiDB: &dialectquery.Tidb{},
DialectVertica: &dialectquery.Vertica{},
DialectYdB: &dialectquery.Ydb{},
DialectTurso: &dialectquery.Turso{},
DialectStarrocks: &dialectquery.Starrocks{},
}
querier, ok := lookup[dialect]
lookup := map[Dialect]dialect.Querier{
DialectClickHouse: dialects.NewClickhouse(),
DialectMSSQL: dialects.NewSqlserver(),
DialectMySQL: dialects.NewMysql(),
DialectPostgres: dialects.NewPostgres(),
DialectRedshift: dialects.NewRedshift(),
DialectSQLite3: dialects.NewSqlite3(),
DialectStarrocks: dialects.NewStarrocks(),
DialectTiDB: dialects.NewTidb(),
DialectTurso: dialects.NewTurso(),
DialectVertica: dialects.NewVertica(),
DialectYdB: dialects.NewYDB(),
}
querier, ok := lookup[d]
if !ok {
return nil, fmt.Errorf("unknown dialect: %q", dialect)
return nil, fmt.Errorf("unknown dialect: %q", d)
}
return NewStoreFromQuerier(tableName, querier)
}

// NewStoreFromQuerier returns a new [Store] implementation for the given querier.
//
// Most of the time you should use [NewStore] instead of this function, as it will automatically
// create a dialect-specific querier for you. This function is useful if you want to use a custom
// querier that is not part of the standard dialects.
func NewStoreFromQuerier(tableName string, querier dialect.Querier) (Store, error) {
if tableName == "" {
return nil, errors.New("table name must not be empty")
}
if querier == nil {
return nil, errors.New("querier must not be nil")
}
return &store{
tablename: tablename,
querier: dialectquery.NewQueryController(querier),
tableName: tableName,
querier: newQueryController(querier),
}, nil
}

type store struct {
tablename string
querier *dialectquery.QueryController
tableName string
querier *queryController
}

var _ Store = (*store)(nil)

func (s *store) Tablename() string {
return s.tablename
return s.tableName
}

func (s *store) CreateVersionTable(ctx context.Context, db DBTxConn) error {
q := s.querier.CreateTable(s.tablename)
q := s.querier.CreateTable(s.tableName)
if _, err := db.ExecContext(ctx, q); err != nil {
return fmt.Errorf("failed to create version table %q: %w", s.tablename, err)
return fmt.Errorf("failed to create version table %q: %w", s.tableName, err)
}
return nil
}

func (s *store) Insert(ctx context.Context, db DBTxConn, req InsertRequest) error {
q := s.querier.InsertVersion(s.tablename)
q := s.querier.InsertVersion(s.tableName)
if _, err := db.ExecContext(ctx, q, req.Version, true); err != nil {
return fmt.Errorf("failed to insert version %d: %w", req.Version, err)
}
return nil
}

func (s *store) Delete(ctx context.Context, db DBTxConn, version int64) error {
q := s.querier.DeleteVersion(s.tablename)
q := s.querier.DeleteVersion(s.tableName)
if _, err := db.ExecContext(ctx, q, version); err != nil {
return fmt.Errorf("failed to delete version %d: %w", version, err)
}
Expand All @@ -97,7 +111,7 @@ func (s *store) GetMigration(
db DBTxConn,
version int64,
) (*GetMigrationResult, error) {
q := s.querier.GetMigrationByVersion(s.tablename)
q := s.querier.GetMigrationByVersion(s.tableName)
var result GetMigrationResult
if err := db.QueryRowContext(ctx, q, version).Scan(
&result.Timestamp,
Expand All @@ -112,7 +126,7 @@ func (s *store) GetMigration(
}

func (s *store) GetLatestVersion(ctx context.Context, db DBTxConn) (int64, error) {
q := s.querier.GetLatestVersion(s.tablename)
q := s.querier.GetLatestVersion(s.tableName)
var version sql.NullInt64
if err := db.QueryRowContext(ctx, q).Scan(&version); err != nil {
return -1, fmt.Errorf("failed to get latest version: %w", err)
Expand All @@ -127,7 +141,7 @@ func (s *store) ListMigrations(
ctx context.Context,
db DBTxConn,
) ([]*ListMigrationsResult, error) {
q := s.querier.ListMigrations(s.tablename)
q := s.querier.ListMigrations(s.tableName)
rows, err := db.QueryContext(ctx, q)
if err != nil {
return nil, fmt.Errorf("failed to list migrations: %w", err)
Expand Down Expand Up @@ -158,7 +172,7 @@ func (s *store) ListMigrations(
//

func (s *store) TableExists(ctx context.Context, db DBTxConn) (bool, error) {
q := s.querier.TableExists(s.tablename)
q := s.querier.TableExists(s.tableName)
if q == "" {
return false, errors.ErrUnsupported
}
Expand All @@ -170,3 +184,25 @@ func (s *store) TableExists(ctx context.Context, db DBTxConn) (bool, error) {
}
return exists, nil
}

var _ dialect.Querier = (*queryController)(nil)

type queryController struct{ dialect.Querier }

// newQueryController returns a new QueryController that wraps the given Querier.
func newQueryController(querier dialect.Querier) *queryController {
return &queryController{Querier: querier}
}

// Optional methods

// TableExists returns the SQL query string to check if the version table exists. If the Querier
// does not implement this method, it will return an empty string.
//
// Returns a boolean value.
func (c *queryController) TableExists(tableName string) string {
if t, ok := c.Querier.(interface{ TableExists(string) string }); ok {
return t.TableExists(tableName)
}
return ""
}
4 changes: 2 additions & 2 deletions database/store_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,9 +79,9 @@ func testStore(
alreadyExists func(t *testing.T, err error),
) {
const (
tablename = "test_goose_db_version"
tableName = "test_goose_db_version"
)
store, err := database.NewStore(d, tablename)
store, err := database.NewStore(d, tableName)
require.NoError(t, err)
// Create the version table.
err = runTx(ctx, db, func(tx *sql.Tx) error {
Expand Down
36 changes: 19 additions & 17 deletions dialect.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,61 +4,63 @@ import (
"fmt"

"github.com/pressly/goose/v3/database"
"github.com/pressly/goose/v3/internal/dialect"
"github.com/pressly/goose/v3/internal/legacystore"
)

// Dialect is the type of database dialect. It is an alias for [database.Dialect].
type Dialect = database.Dialect

const (
DialectCustom Dialect = database.DialectCustom
DialectClickHouse Dialect = database.DialectClickHouse
DialectMSSQL Dialect = database.DialectMSSQL
DialectMySQL Dialect = database.DialectMySQL
DialectPostgres Dialect = database.DialectPostgres
DialectRedshift Dialect = database.DialectRedshift
DialectSQLite3 Dialect = database.DialectSQLite3
DialectStarrocks Dialect = database.DialectStarrocks
DialectTiDB Dialect = database.DialectTiDB
DialectTurso Dialect = database.DialectTurso
DialectVertica Dialect = database.DialectVertica
DialectYdB Dialect = database.DialectYdB
DialectStarrocks Dialect = database.DialectStarrocks
)

func init() {
store, _ = dialect.NewStore(dialect.Postgres)
store, _ = legacystore.NewStore(DialectPostgres)
}

var store dialect.Store
var store legacystore.Store

// SetDialect sets the dialect to use for the goose package.
func SetDialect(s string) error {
var d dialect.Dialect
var d Dialect
switch s {
case "postgres", "pgx":
d = dialect.Postgres
d = DialectPostgres
case "mysql":
d = dialect.Mysql
d = DialectMySQL
case "sqlite3", "sqlite":
d = dialect.Sqlite3
d = DialectSQLite3
case "mssql", "azuresql", "sqlserver":
d = dialect.Sqlserver
d = DialectMSSQL
case "redshift":
d = dialect.Redshift
d = DialectRedshift
case "tidb":
d = dialect.Tidb
d = DialectTiDB
case "clickhouse":
d = dialect.Clickhouse
d = DialectClickHouse
case "vertica":
d = dialect.Vertica
d = DialectVertica
case "ydb":
d = dialect.Ydb
d = DialectYdB
case "turso":
d = dialect.Turso
d = DialectTurso
case "starrocks":
d = dialect.Starrocks
d = DialectStarrocks
default:
return fmt.Errorf("%q: unknown dialect", s)
}
var err error
store, err = dialect.NewStore(d)
store, err = legacystore.NewStore(d)
return err
}
File renamed without changes.
7 changes: 0 additions & 7 deletions internal/dialect/dialectquery/turso.go

This file was deleted.

Loading
Loading