Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
102 changes: 102 additions & 0 deletions dbos/dbos_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package dbos

import (
"context"
"testing"
"time"

Expand Down Expand Up @@ -166,4 +167,105 @@ func TestConfig(t *testing.T) {
assert.Equal(t, "env-only-executor", ctx.GetExecutorID())
})
})

t.Run("SystemDBMigration", func(t *testing.T) {
t.Setenv("DBOS__APPVERSION", "v1.0.0")
t.Setenv("DBOS__APPID", "test-migration")
t.Setenv("DBOS__VMID", "test-executor-id")

ctx, err := NewDBOSContext(Config{
DatabaseURL: databaseURL,
AppName: "test-migration",
})
require.NoError(t, err)
defer func() {
if ctx != nil {
ctx.Shutdown(1 * time.Minute)
}
}()

require.NotNil(t, ctx)

// Get the internal systemDB instance to check tables directly
dbosCtx, ok := ctx.(*dbosContext)
require.True(t, ok, "expected dbosContext")
require.NotNil(t, dbosCtx.systemDB)

sysDB, ok := dbosCtx.systemDB.(*sysDB)
require.True(t, ok, "expected sysDB")

// Verify all expected tables exist and have correct structure
dbCtx := context.Background()

// Test workflow_status table
var exists bool
err = sysDB.pool.QueryRow(dbCtx, "SELECT EXISTS (SELECT FROM information_schema.tables WHERE table_schema = 'dbos' AND table_name = 'workflow_status')").Scan(&exists)
require.NoError(t, err)
assert.True(t, exists, "workflow_status table should exist")

// Test operation_outputs table
err = sysDB.pool.QueryRow(dbCtx, "SELECT EXISTS (SELECT FROM information_schema.tables WHERE table_schema = 'dbos' AND table_name = 'operation_outputs')").Scan(&exists)
require.NoError(t, err)
assert.True(t, exists, "operation_outputs table should exist")

// Test workflow_events table
err = sysDB.pool.QueryRow(dbCtx, "SELECT EXISTS (SELECT FROM information_schema.tables WHERE table_schema = 'dbos' AND table_name = 'workflow_events')").Scan(&exists)
require.NoError(t, err)
assert.True(t, exists, "workflow_events table should exist")

// Test notifications table
err = sysDB.pool.QueryRow(dbCtx, "SELECT EXISTS (SELECT FROM information_schema.tables WHERE table_schema = 'dbos' AND table_name = 'notifications')").Scan(&exists)
require.NoError(t, err)
assert.True(t, exists, "notifications table should exist")

// Test that all tables can be queried (empty results expected)
rows, err := sysDB.pool.Query(dbCtx, "SELECT workflow_uuid FROM dbos.workflow_status LIMIT 1")
require.NoError(t, err)
rows.Close()

rows, err = sysDB.pool.Query(dbCtx, "SELECT workflow_uuid FROM dbos.operation_outputs LIMIT 1")
require.NoError(t, err)
rows.Close()

rows, err = sysDB.pool.Query(dbCtx, "SELECT workflow_uuid FROM dbos.workflow_events LIMIT 1")
require.NoError(t, err)
rows.Close()

rows, err = sysDB.pool.Query(dbCtx, "SELECT destination_uuid FROM dbos.notifications LIMIT 1")
require.NoError(t, err)
rows.Close()

// Check that the dbos_migrations table exists and has one row with the correct version
err = sysDB.pool.QueryRow(dbCtx, "SELECT EXISTS (SELECT FROM information_schema.tables WHERE table_schema = 'dbos' AND table_name = 'dbos_migrations')").Scan(&exists)
require.NoError(t, err)
assert.True(t, exists, "dbos_migrations table should exist")

// Verify migration version is 1 (after initial migration)
var version int64
var count int
err = sysDB.pool.QueryRow(dbCtx, "SELECT COUNT(*) FROM dbos.dbos_migrations").Scan(&count)
require.NoError(t, err)
assert.Equal(t, 1, count, "dbos_migrations table should have exactly one row")

err = sysDB.pool.QueryRow(dbCtx, "SELECT version FROM dbos.dbos_migrations").Scan(&version)
require.NoError(t, err)
assert.Equal(t, int64(1), version, "migration version should be 1 (after initial migration)")

// Test manual shutdown and recreate
ctx.Shutdown(1 * time.Minute)

// Recreate context - should have no error since DB is already migrated
ctx2, err := NewDBOSContext(Config{
DatabaseURL: databaseURL,
AppName: "test-migration-recreate",
})
require.NoError(t, err)
defer func() {
if ctx2 != nil {
ctx2.Shutdown(1 * time.Minute)
}
}()

require.NotNil(t, ctx2)
})
}
18 changes: 0 additions & 18 deletions dbos/migrations/000001_initial_dbos_schema.down.sql

This file was deleted.

153 changes: 119 additions & 34 deletions dbos/system_database.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,15 @@ import (
"embed"
"errors"
"fmt"
"io/fs"
"log/slog"
"net/url"
"path/filepath"
"sort"
"strconv"
"strings"
"sync"
"time"

"github.com/golang-migrate/migrate/v4"
_ "github.com/golang-migrate/migrate/v4/database/pgx/v5"
"github.com/golang-migrate/migrate/v4/source/iofs"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgconn"
"github.com/jackc/pgx/v5/pgxpool"
Expand Down Expand Up @@ -123,7 +123,7 @@ func createDatabaseIfNotExists(ctx context.Context, databaseURL string, logger *
var migrationFiles embed.FS

const (
_DBOS_MIGRATION_TABLE = "dbos_schema_migrations"
_DBOS_MIGRATION_TABLE = "dbos_migrations"

// PostgreSQL error codes
_PG_ERROR_UNIQUE_VIOLATION = "23505"
Expand All @@ -139,52 +139,137 @@ const (
)

func runMigrations(databaseURL string) error {
// Change the driver to pgx5
parsedURL, err := url.Parse(databaseURL)
// Connect to the database
pool, err := pgxpool.New(context.Background(), databaseURL)
if err != nil {
return fmt.Errorf("failed to parse database URL: %v", err)
return fmt.Errorf("failed to create connection pool: %v", err)
}
// Handle various PostgreSQL URL schemes
switch parsedURL.Scheme {
case "postgres", "postgresql":
parsedURL.Scheme = "pgx5"
case "pgx5":
// Already in correct format
default:
return fmt.Errorf("unsupported database URL scheme: %s", parsedURL.Scheme)
defer pool.Close()

// Begin transaction for atomic migration execution
ctx := context.Background()
tx, err := pool.Begin(ctx)
if err != nil {
return fmt.Errorf("failed to begin transaction: %v", err)
}
databaseURL = parsedURL.String()
defer tx.Rollback(ctx)

// Add custom migration table name to avoid conflicts with user migrations
// Check if query parameters already exist
separator := "?"
if parsedURL.RawQuery != "" {
separator = "&"
// Create the DBOS schema if it doesn't exist
_, err = tx.Exec(ctx, "CREATE SCHEMA IF NOT EXISTS dbos")
if err != nil {
return fmt.Errorf("failed to create DBOS schema: %v", err)
}
databaseURL += separator + "x-migrations-table=" + _DBOS_MIGRATION_TABLE

// Create migration source from embedded files
d, err := iofs.New(migrationFiles, "migrations")
// Create the migrations table if it doesn't exist
createTableQuery := fmt.Sprintf(`CREATE TABLE IF NOT EXISTS dbos.%s (
version BIGINT NOT NULL PRIMARY KEY
)`, _DBOS_MIGRATION_TABLE)

_, err = tx.Exec(ctx, createTableQuery)
if err != nil {
return fmt.Errorf("failed to create migration source: %v", err)
return fmt.Errorf("failed to create migrations table: %v", err)
}

// Get current migration version
var currentVersion int64 = 0
query := fmt.Sprintf("SELECT version FROM dbos.%s LIMIT 1", _DBOS_MIGRATION_TABLE)
err = tx.QueryRow(ctx, query).Scan(&currentVersion)
if err != nil && err != pgx.ErrNoRows {
return fmt.Errorf("failed to get current migration version: %v", err)
}

// Create migrator
m, err := migrate.NewWithSourceInstance("iofs", d, databaseURL)
// Read and parse migration files
migrations, err := parseMigrationFiles()
if err != nil {
return fmt.Errorf("failed to create migrator: %v", err)
return fmt.Errorf("failed to parse migration files: %v", err)
}
defer m.Close()

// Run migrations
// FIXME: tolerate errors when the migration is bcz we run an older version of transact
if err := m.Up(); err != nil && err != migrate.ErrNoChange {
return fmt.Errorf("failed to run migrations: %v", err)
// Apply migrations starting from the next version
for _, migration := range migrations {
if migration.version <= currentVersion {
continue
}

// Execute the migration SQL
_, err = tx.Exec(ctx, migration.sql)
if err != nil {
return fmt.Errorf("failed to execute migration %d: %v", migration.version, err)
}

// Update the migration version
if currentVersion == 0 {
// Insert first migration record
insertQuery := fmt.Sprintf("INSERT INTO dbos.%s (version) VALUES ($1)", _DBOS_MIGRATION_TABLE)
_, err = tx.Exec(ctx, insertQuery, migration.version)
} else {
// Update existing migration record
updateQuery := fmt.Sprintf("UPDATE dbos.%s SET version = $1", _DBOS_MIGRATION_TABLE)
_, err = tx.Exec(ctx, updateQuery, migration.version)
}
if err != nil {
return fmt.Errorf("failed to update migration version to %d: %v", migration.version, err)
}

currentVersion = migration.version
}

// Commit the transaction
if err := tx.Commit(ctx); err != nil {
return fmt.Errorf("failed to commit migration transaction: %v", err)
}

return nil
}

type migrationFile struct {
version int64
sql string
}

func parseMigrationFiles() ([]migrationFile, error) {
var migrations []migrationFile

entries, err := fs.ReadDir(migrationFiles, "migrations")
if err != nil {
return nil, fmt.Errorf("failed to read migration directory: %v", err)
}

for _, entry := range entries {
if entry.IsDir() || !strings.HasSuffix(entry.Name(), ".sql") {
continue
}

// Extract version from filename (e.g., "1_initial_dbos_schema.sql" -> 1)
parts := strings.SplitN(entry.Name(), "_", 2)
if len(parts) < 2 {
continue
}

version, err := strconv.ParseInt(parts[0], 10, 64)
if err != nil {
continue // Skip files with invalid version format
}

// Read migration SQL content
content, err := fs.ReadFile(migrationFiles, filepath.Join("migrations", entry.Name()))
if err != nil {
return nil, fmt.Errorf("failed to read migration file %s: %v", entry.Name(), err)
}

migrations = append(migrations, migrationFile{
version: version,
sql: string(content),
})
}

// Sort migrations by version
sort.Slice(migrations, func(i, j int) bool {
return migrations[i].version < migrations[j].version
})

return migrations, nil
}

// New creates a new SystemDatabase instance and runs migrations
func newSystemDatabase(ctx context.Context, databaseURL string, logger *slog.Logger) (systemDatabase, error) {
// Create the database if it doesn't exist
Expand Down
5 changes: 2 additions & 3 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ toolchain go1.25.0
require (
github.com/docker/docker v28.3.3+incompatible
github.com/docker/go-connections v0.5.0
github.com/golang-migrate/migrate/v4 v4.18.3
github.com/google/uuid v1.6.0
github.com/gorilla/websocket v1.5.3
github.com/jackc/pgerrcode v0.0.0-20220416144525-469b46aa5efa
Expand All @@ -33,14 +32,14 @@ require (
github.com/go-logr/stdr v1.2.2 // indirect
github.com/go-viper/mapstructure/v2 v2.2.1 // indirect
github.com/gogo/protobuf v1.3.2 // indirect
github.com/hashicorp/errwrap v1.1.0 // indirect
github.com/hashicorp/go-multierror v1.1.1 // indirect
github.com/inconshreveable/mousetrap v1.1.0 // indirect
github.com/jackc/pgpassfile v1.0.0 // indirect
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect
github.com/jackc/puddle/v2 v2.2.2 // indirect
github.com/moby/docker-image-spec v1.3.1 // indirect
github.com/moby/sys/atomicwriter v0.1.0 // indirect
github.com/moby/term v0.5.2 // indirect
github.com/morikuni/aec v1.0.0 // indirect
github.com/opencontainers/go-digest v1.0.0 // indirect
github.com/opencontainers/image-spec v1.1.0 // indirect
github.com/pelletier/go-toml/v2 v2.2.3 // indirect
Expand Down
Loading