diff --git a/.vscode/settings.json b/.vscode/settings.json index 84f4acf..e097931 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -23,7 +23,6 @@ "go.formatFlags": [ "-extra" ], - "go.formatTool": "gofumpt", "go.lintTool": "golangci-lint-v2", "go.useLanguageServer": true, "gopls": { @@ -40,5 +39,31 @@ "tests" ], "python.testing.pytestEnabled": true, - "python.testing.unittestEnabled": false + "python.testing.unittestEnabled": false, + "cSpell.words": [ + "bson", + "clustersync", + "cmdutil", + "codegen", + "colls", + "connstring", + "contextcheck", + "Debugf", + "errgroup", + "errorlint", + "Infof", + "keygen", + "mapstructure", + "nolint", + "opencode", + "pcsm", + "pipefail", + "readconcern", + "readpref", + "Warnf", + "wrapcheck", + "Wrapf", + "writeconcern", + "zerolog" + ] } diff --git a/config/config.go b/config/config.go new file mode 100644 index 0000000..ab7db0f --- /dev/null +++ b/config/config.go @@ -0,0 +1,205 @@ +// Package config provides configuration management for PCSM using Viper. +package config + +import ( + "context" + "math" + "os" + "slices" + "strings" + "time" + + "github.com/dustin/go-humanize" + "github.com/go-viper/mapstructure/v2" + "github.com/spf13/cobra" + "github.com/spf13/viper" + + "github.com/percona/percona-clustersync-mongodb/errors" + "github.com/percona/percona-clustersync-mongodb/log" +) + +// Config holds all PCSM configuration. +type Config struct { + Port int `mapstructure:"port"` + Source string `mapstructure:"source"` + Target string `mapstructure:"target"` + + Log LogConfig `mapstructure:",squash"` + + MongoDB MongoDBConfig `mapstructure:",squash"` + + UseCollectionBulkWrite bool `mapstructure:"use-collection-bulk-write"` + + Clone CloneConfig `mapstructure:",squash"` + + // hidden startup flags + Start bool `mapstructure:"start"` + ResetState bool `mapstructure:"reset-state"` + PauseOnInitialSync bool `mapstructure:"pause-on-initial-sync"` +} + +// LogConfig holds logging configuration. +type LogConfig struct { + Level string `mapstructure:"log-level"` + JSON bool `mapstructure:"log-json"` + NoColor bool `mapstructure:"log-no-color"` +} + +// MongoDBConfig holds MongoDB client configuration. +type MongoDBConfig struct { + OperationTimeout time.Duration `mapstructure:"mongodb-operation-timeout"` + TargetCompressors []string `mapstructure:"dev-target-client-compressors"` +} + +// CloneConfig holds clone operation configuration. +type CloneConfig struct { + // NumParallelCollections is the number of collections to clone in parallel. + // 0 means auto (calculated at runtime). + NumParallelCollections int `mapstructure:"clone-num-parallel-collections"` + // NumReadWorkers is the number of read workers during clone. + // 0 means auto (calculated at runtime). + NumReadWorkers int `mapstructure:"clone-num-read-workers"` + // NumInsertWorkers is the number of insert workers during clone. + // 0 means auto (calculated at runtime). + NumInsertWorkers int `mapstructure:"clone-num-insert-workers"` + // SegmentSize is the segment size for clone operations (e.g., "500MB", "1GiB"). + // Empty string means auto (calculated at runtime for each collection). + SegmentSize string `mapstructure:"clone-segment-size"` + // ReadBatchSize is the read batch size during clone (e.g., "16MiB", "100MB"). + // Empty string means auto (calculated at runtime for each collection). + ReadBatchSize string `mapstructure:"clone-read-batch-size"` +} + +// Load initializes Viper and populates the provided Config. +func Load(cmd *cobra.Command, cfg *Config) error { + viper.SetEnvPrefix("PCSM") + viper.SetEnvKeyReplacer(strings.NewReplacer("-", "_")) + viper.AutomaticEnv() + + if cmd.PersistentFlags() != nil { + _ = viper.BindPFlags(cmd.PersistentFlags()) + } + + if cmd.Flags() != nil { + _ = viper.BindPFlags(cmd.Flags()) + } + + bindEnvVars() + + err := viper.Unmarshal(cfg, viper.DecodeHook( + mapstructure.ComposeDecodeHookFunc( + mapstructure.StringToTimeDurationHookFunc(), + mapstructure.StringToSliceHookFunc(","), + ), + )) + if err != nil { + return errors.Wrap(err, "unmarshal config") + } + + cfg.MongoDB.TargetCompressors = filterCompressors(cfg.MongoDB.TargetCompressors) + + if viper.GetBool("no-color") { + cfg.Log.NoColor = true + } + + return nil +} + +// WarnDeprecatedEnvVars logs warnings for any deprecated environment variables that are set. +// Expects the logger to be initialized. +func WarnDeprecatedEnvVars(ctx context.Context) { + deprecated := map[string]string{ + "PLM_MONGODB_CLI_OPERATION_TIMEOUT": "PCSM_MONGODB_OPERATION_TIMEOUT", + "PCSM_NO_COLOR": "PCSM_LOG_NO_COLOR", + } + + for old, replacement := range deprecated { + if _, ok := os.LookupEnv(old); ok { + log.Ctx(ctx).Warnf( + "Environment variable %s is deprecated; use %s instead", + old, replacement, + ) + } + } +} + +func bindEnvVars() { + _ = viper.BindEnv("port", "PCSM_PORT") + + _ = viper.BindEnv("source", "PCSM_SOURCE_URI") + _ = viper.BindEnv("target", "PCSM_TARGET_URI") + + _ = viper.BindEnv("log-level", "PCSM_LOG_LEVEL") + _ = viper.BindEnv("log-json", "PCSM_LOG_JSON") + _ = viper.BindEnv("log-no-color", + "PCSM_LOG_NO_COLOR", + "PCSM_NO_COLOR", // deprecated + ) + + _ = viper.BindEnv("mongodb-operation-timeout", + "PCSM_MONGODB_OPERATION_TIMEOUT", + "PLM_MONGODB_CLI_OPERATION_TIMEOUT", // deprecated + ) + + _ = viper.BindEnv("use-collection-bulk-write", "PCSM_USE_COLLECTION_BULK_WRITE") + + _ = viper.BindEnv("dev-target-client-compressors", "PCSM_DEV_TARGET_CLIENT_COMPRESSORS") + + _ = viper.BindEnv("clone-num-parallel-collections", "PCSM_CLONE_NUM_PARALLEL_COLLECTIONS") + _ = viper.BindEnv("clone-num-read-workers", "PCSM_CLONE_NUM_READ_WORKERS") + _ = viper.BindEnv("clone-num-insert-workers", "PCSM_CLONE_NUM_INSERT_WORKERS") + _ = viper.BindEnv("clone-segment-size", "PCSM_CLONE_SEGMENT_SIZE") + _ = viper.BindEnv("clone-read-batch-size", "PCSM_CLONE_READ_BATCH_SIZE") +} + +//nolint:gochecknoglobals +var allowedCompressors = []string{"zstd", "zlib", "snappy"} + +func filterCompressors(compressors []string) []string { + if len(compressors) == 0 { + return nil + } + + filtered := make([]string, 0, len(allowedCompressors)) + + for _, c := range compressors { + c = strings.TrimSpace(c) + if slices.Contains(allowedCompressors, c) && !slices.Contains(filtered, c) { + filtered = append(filtered, c) + } + } + + return filtered +} + +// ParseAndValidateCloneSegmentSize parses a byte size string and validates it. +// It allows 0 (auto) or values within [MinCloneSegmentSizeBytes, MaxCloneSegmentSizeBytes]. +func ParseAndValidateCloneSegmentSize(value string) (int64, error) { + sizeBytes, err := humanize.ParseBytes(value) + if err != nil { + return 0, errors.Wrapf(err, "invalid cloneSegmentSize value: %s", value) + } + + err = ValidateCloneSegmentSize(sizeBytes) + if err != nil { + return 0, err + } + + return int64(min(sizeBytes, math.MaxInt64)), nil //nolint:gosec +} + +// ParseAndValidateCloneReadBatchSize parses a byte size string and validates it. +// It allows 0 (auto) or values within [[MinCloneReadBatchSizeBytes], [MaxCloneReadBatchSizeBytes]]. +func ParseAndValidateCloneReadBatchSize(value string) (int32, error) { + sizeBytes, err := humanize.ParseBytes(value) + if err != nil { + return 0, errors.Wrapf(err, "invalid cloneReadBatchSize value: %s", value) + } + + err = ValidateCloneReadBatchSize(sizeBytes) + if err != nil { + return 0, err + } + + return int32(min(sizeBytes, math.MaxInt32)), nil //nolint:gosec +} diff --git a/config/config_test.go b/config/config_test.go new file mode 100644 index 0000000..03e7d6c --- /dev/null +++ b/config/config_test.go @@ -0,0 +1,148 @@ +package config_test + +import ( + "fmt" + "testing" + + "github.com/dustin/go-humanize" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/percona/percona-clustersync-mongodb/config" +) + +func TestParseAndValidateCloneSegmentSize(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + value string + want int64 + wantErr string + }{ + { + name: "valid size 500MB (above minimum)", + value: "500MB", + want: 500 * humanize.MByte, + }, + { + name: "valid size 1GiB", + value: "1GiB", + want: humanize.GiByte, + }, + { + name: "zero value (auto)", + value: "0", + want: 0, + }, + { + name: "below minimum (100MB)", + value: "100MB", + wantErr: "cloneSegmentSize must be at least", + }, + { + name: "above maximum", + value: "100GiB", + wantErr: "cloneSegmentSize must be at most", + }, + { + name: "at minimum boundary (using exact bytes)", + value: fmt.Sprintf("%dB", config.MinCloneSegmentSizeBytes), + want: int64(config.MinCloneSegmentSizeBytes), + }, + { + name: "at maximum boundary", + value: "64GiB", + want: int64(config.MaxCloneSegmentSizeBytes), + }, + { + name: "invalid format", + value: "abc", + wantErr: "invalid cloneSegmentSize value", + }, + { + name: "empty string", + value: "", + wantErr: "invalid cloneSegmentSize value", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + got, err := config.ParseAndValidateCloneSegmentSize(tt.value) + + if tt.wantErr == "" { + require.NoError(t, err) + assert.Equal(t, tt.want, got) + } else { + require.Error(t, err) + assert.Contains(t, err.Error(), tt.wantErr) + } + }) + } +} + +func TestParseAndValidateCloneReadBatchSize(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + value string + want int32 + wantErr string + }{ + { + name: "valid size 16MiB", + value: "16MiB", + want: 16 * humanize.MiByte, + }, + { + name: "valid size 48MB", + value: "48MB", + want: 48 * humanize.MByte, + }, + { + name: "zero value (auto)", + value: "0", + want: 0, + }, + { + name: "below minimum", + value: "1KB", + wantErr: "cloneReadBatchSize must be at least", + }, + { + name: "at minimum boundary (using exact bytes)", + value: fmt.Sprintf("%dB", config.MinCloneReadBatchSizeBytes), + want: config.MinCloneReadBatchSizeBytes, + }, + { + name: "invalid format", + value: "xyz", + wantErr: "invalid cloneReadBatchSize value", + }, + { + name: "empty string", + value: "", + wantErr: "invalid cloneReadBatchSize value", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + got, err := config.ParseAndValidateCloneReadBatchSize(tt.value) + + if tt.wantErr == "" { + require.NoError(t, err) + assert.Equal(t, tt.want, got) + } else { + require.Error(t, err) + assert.Contains(t, err.Error(), tt.wantErr) + } + }) + } +} diff --git a/config/const.go b/config/const.go index 66ea49d..99820df 100644 --- a/config/const.go +++ b/config/const.go @@ -44,10 +44,9 @@ const ( DisconnectTimeout = 5 * time.Second // CloseCursorTimeout is the timeout duration for closing cursor. CloseCursorTimeout = 10 * time.Second - // DefaultMongoDBCliOperationTimeout is the default timeout duration for MongoDB client - // operations like insert, update, delete, etc. It can be overridden via - // environment variable (see config.OperationMongoDBCliTimeout()). - DefaultMongoDBCliOperationTimeout = 5 * time.Minute + // DefaultMongoDBOperationTimeout is the default timeout for MongoDB client operations. + // Override via --mongodb-operation-timeout flag or PCSM_MONGODB_OPERATION_TIMEOUT env var. + DefaultMongoDBOperationTimeout = 5 * time.Minute ) // Change stream and replication settings. diff --git a/config/validate.go b/config/validate.go new file mode 100644 index 0000000..4fbbeaa --- /dev/null +++ b/config/validate.go @@ -0,0 +1,79 @@ +package config + +import ( + "github.com/dustin/go-humanize" + + "github.com/percona/percona-clustersync-mongodb/errors" +) + +// DefaultServerPort is the default port for the PCSM HTTP server. +const DefaultServerPort = 2242 + +// Validate validates the Config for required fields and value ranges. +func Validate(cfg *Config) error { + port := cfg.Port + if port == 0 { + port = DefaultServerPort + } + + if port <= 1024 || port > 65535 { + return errors.New("port value is outside the supported range [1024 - 65535]") + } + + switch { + case cfg.Source == "" && cfg.Target == "": + return errors.New("source URI and target URI are empty") + case cfg.Source == "": + return errors.New("source URI is empty") + case cfg.Target == "": + return errors.New("target URI is empty") + case cfg.Source == cfg.Target: + return errors.New("source URI and target URI are identical") + } + + return nil +} + +// ValidateCloneSegmentSize validates a clone segment size value in bytes. +// It allows 0 (auto) or values within [MinCloneSegmentSizeBytes, MaxCloneSegmentSizeBytes]. +func ValidateCloneSegmentSize(sizeBytes uint64) error { + if sizeBytes == 0 { + return nil // 0 means auto + } + + if sizeBytes < MinCloneSegmentSizeBytes { + return errors.Errorf("cloneSegmentSize must be at least %s, got %s", + humanize.Bytes(MinCloneSegmentSizeBytes), + humanize.Bytes(sizeBytes)) + } + + if sizeBytes > MaxCloneSegmentSizeBytes { + return errors.Errorf("cloneSegmentSize must be at most %s, got %s", + humanize.Bytes(MaxCloneSegmentSizeBytes), + humanize.Bytes(sizeBytes)) + } + + return nil +} + +// ValidateCloneReadBatchSize validates a clone read batch size value in bytes. +// It allows 0 (auto) or values within [MinCloneReadBatchSizeBytes, MaxCloneReadBatchSizeBytes]. +func ValidateCloneReadBatchSize(sizeBytes uint64) error { + if sizeBytes == 0 { + return nil // 0 means auto + } + + if sizeBytes < uint64(MinCloneReadBatchSizeBytes) { + return errors.Errorf("cloneReadBatchSize must be at least %s, got %s", + humanize.Bytes(uint64(MinCloneReadBatchSizeBytes)), + humanize.Bytes(sizeBytes)) + } + + if sizeBytes > uint64(MaxCloneReadBatchSizeBytes) { + return errors.Errorf("cloneReadBatchSize must be at most %s, got %s", + humanize.Bytes(uint64(MaxCloneReadBatchSizeBytes)), + humanize.Bytes(sizeBytes)) + } + + return nil +} diff --git a/config/validate_test.go b/config/validate_test.go new file mode 100644 index 0000000..414f87a --- /dev/null +++ b/config/validate_test.go @@ -0,0 +1,297 @@ +package config_test + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/percona/percona-clustersync-mongodb/config" +) + +func TestValidate(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + cfg *config.Config + wantErr string + }{ + { + name: "valid config", + cfg: &config.Config{ + Port: 8080, + Source: "mongodb://source:27017", + Target: "mongodb://target:27017", + }, + wantErr: "", + }, + { + name: "port zero uses default - valid", + cfg: &config.Config{ + Port: 0, + Source: "mongodb://source:27017", + Target: "mongodb://target:27017", + }, + wantErr: "", + }, + { + name: "port at lower bound (1025) - valid", + cfg: &config.Config{ + Port: 1025, + Source: "mongodb://source:27017", + Target: "mongodb://target:27017", + }, + wantErr: "", + }, + { + name: "port at upper bound (65535) - valid", + cfg: &config.Config{ + Port: 65535, + Source: "mongodb://source:27017", + Target: "mongodb://target:27017", + }, + wantErr: "", + }, + { + name: "port below range (1024)", + cfg: &config.Config{ + Port: 1024, + Source: "mongodb://source:27017", + Target: "mongodb://target:27017", + }, + wantErr: "port value is outside the supported range", + }, + { + name: "port above range (65536)", + cfg: &config.Config{ + Port: 65536, + Source: "mongodb://source:27017", + Target: "mongodb://target:27017", + }, + wantErr: "port value is outside the supported range", + }, + { + name: "source empty", + cfg: &config.Config{ + Port: 8080, + Source: "", + Target: "mongodb://target:27017", + }, + wantErr: "source URI is empty", + }, + { + name: "target empty", + cfg: &config.Config{ + Port: 8080, + Source: "mongodb://source:27017", + Target: "", + }, + wantErr: "target URI is empty", + }, + { + name: "both source and target empty", + cfg: &config.Config{ + Port: 8080, + Source: "", + Target: "", + }, + wantErr: "source URI and target URI are empty", + }, + { + name: "source equals target", + cfg: &config.Config{ + Port: 8080, + Source: "mongodb://same:27017", + Target: "mongodb://same:27017", + }, + wantErr: "source URI and target URI are identical", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + err := config.Validate(tt.cfg) + + if tt.wantErr == "" { + assert.NoError(t, err) + } else { + require.Error(t, err) + assert.Contains(t, err.Error(), tt.wantErr) + } + }) + } +} + +func TestValidate_PortBoundaries(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + port int + wantErr bool + }{ + // Invalid ports (at or below 1024) + {"port 1 - invalid", 1, true}, + {"port 100 - invalid", 100, true}, + {"port 1023 - invalid", 1023, true}, + {"port 1024 - invalid (boundary)", 1024, true}, + + // Valid ports (above 1024 and up to 65535) + {"port 1025 - valid (lower boundary)", 1025, false}, + {"port 2242 - valid (default)", 2242, false}, + {"port 8080 - valid (common)", 8080, false}, + {"port 27017 - valid (MongoDB default)", 27017, false}, + {"port 65535 - valid (upper boundary)", 65535, false}, + + // Invalid ports (above 65535) + {"port 65536 - invalid (above max)", 65536, true}, + {"port 70000 - invalid", 70000, true}, + {"port 100000 - invalid", 100000, true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + cfg := &config.Config{ + Port: tt.port, + Source: "mongodb://source:27017", + Target: "mongodb://target:27017", + } + + err := config.Validate(cfg) + + if tt.wantErr { + require.Error(t, err) + assert.Contains(t, err.Error(), "port value is outside") + } else { + require.NoError(t, err) + } + }) + } +} + +func TestValidateCloneSegmentSize(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + sizeBytes uint64 + wantErr string + }{ + { + name: "zero (auto) - valid", + sizeBytes: 0, + wantErr: "", + }, + { + name: "at minimum boundary - valid", + sizeBytes: config.MinCloneSegmentSizeBytes, + wantErr: "", + }, + { + name: "above minimum - valid", + sizeBytes: config.MinCloneSegmentSizeBytes + 1, + wantErr: "", + }, + { + name: "at maximum boundary - valid", + sizeBytes: config.MaxCloneSegmentSizeBytes, + wantErr: "", + }, + { + name: "below minimum", + sizeBytes: config.MinCloneSegmentSizeBytes - 1, + wantErr: "cloneSegmentSize must be at least", + }, + { + name: "above maximum", + sizeBytes: config.MaxCloneSegmentSizeBytes + 1, + wantErr: "cloneSegmentSize must be at most", + }, + { + name: "1 byte (below minimum)", + sizeBytes: 1, + wantErr: "cloneSegmentSize must be at least", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + err := config.ValidateCloneSegmentSize(tt.sizeBytes) + + if tt.wantErr == "" { + assert.NoError(t, err) + } else { + require.Error(t, err) + assert.Contains(t, err.Error(), tt.wantErr) + } + }) + } +} + +func TestValidateCloneReadBatchSize(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + sizeBytes uint64 + wantErr string + }{ + { + name: "zero (auto) - valid", + sizeBytes: 0, + wantErr: "", + }, + { + name: "at minimum boundary - valid", + sizeBytes: uint64(config.MinCloneReadBatchSizeBytes), + wantErr: "", + }, + { + name: "above minimum - valid", + sizeBytes: uint64(config.MinCloneReadBatchSizeBytes) + 1, + wantErr: "", + }, + { + name: "at maximum boundary - valid", + sizeBytes: uint64(config.MaxCloneReadBatchSizeBytes), + wantErr: "", + }, + { + name: "below minimum", + sizeBytes: uint64(config.MinCloneReadBatchSizeBytes) - 1, + wantErr: "cloneReadBatchSize must be at least", + }, + { + name: "above maximum", + sizeBytes: uint64(config.MaxCloneReadBatchSizeBytes) + 1, + wantErr: "cloneReadBatchSize must be at most", + }, + { + name: "1 byte (below minimum)", + sizeBytes: 1, + wantErr: "cloneReadBatchSize must be at least", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + err := config.ValidateCloneReadBatchSize(tt.sizeBytes) + + if tt.wantErr == "" { + assert.NoError(t, err) + } else { + require.Error(t, err) + assert.Contains(t, err.Error(), tt.wantErr) + } + }) + } +} diff --git a/config/values.go b/config/values.go deleted file mode 100644 index 46d4280..0000000 --- a/config/values.go +++ /dev/null @@ -1,100 +0,0 @@ -package config - -import ( - "math" - "os" - "slices" - "strconv" - "strings" - "time" - - "github.com/dustin/go-humanize" -) - -// UseCollectionBulkWrite determines whether to use the Collection Bulk Write API -// instead of the Client Bulk Write API (introduced in MongoDB v8.0). -// Enabled when the PCSM_USE_COLLECTION_BULK_WRITE environment variable is set to "1". -func UseCollectionBulkWrite() bool { - return os.Getenv("PCSM_USE_COLLECTION_BULK_WRITE") == "1" -} - -// CloneNumParallelCollections returns the number of collections cloned in parallel -// during the clone process. Default is 0. -func CloneNumParallelCollections() int { - numColl, _ := strconv.ParseInt(os.Getenv("PCSM_CLONE_NUM_PARALLEL_COLLECTIONS"), 10, 32) - - return int(numColl) -} - -// CloneNumReadWorkers returns the number of read workers used during the clone. Default is 0. -// Note: Workers are shared across all collections. -func CloneNumReadWorkers() int { - numReadWorker, _ := strconv.ParseInt(os.Getenv("PCSM_CLONE_NUM_READ_WORKERS"), 10, 32) - - return int(numReadWorker) -} - -// CloneNumInsertWorkers returns the number of insert workers used during the clone. Default is 0. -// Note: Workers are shared across all collections. -func CloneNumInsertWorkers() int { - numInsertWorker, _ := strconv.ParseInt(os.Getenv("PCSM_CLONE_NUM_INSERT_WORKERS"), 10, 32) - - return int(numInsertWorker) -} - -// CloneSegmentSizeBytes returns the segment size in bytes used during the clone. -// A segment is a range within a collection (by _id) that enables concurrent read/insert -// operations by splitting the collection into multiple parallelizable units. -// Zero or less enables auto size (per each collection). Default is [AutoCloneSegmentSize]. -func CloneSegmentSizeBytes() int64 { - segmentSizeBytes, _ := humanize.ParseBytes(os.Getenv("PCSM_CLONE_SEGMENT_SIZE")) - if segmentSizeBytes == 0 { - return AutoCloneSegmentSize - } - - return int64(min(segmentSizeBytes, math.MaxInt64)) //nolint:gosec -} - -// CloneReadBatchSizeBytes returns the read batch size in bytes used during the clone. Default is 0. -func CloneReadBatchSizeBytes() int32 { - batchSizeBytes, _ := humanize.ParseBytes(os.Getenv("PCSM_CLONE_READ_BATCH_SIZE")) - - return int32(min(batchSizeBytes, math.MaxInt32)) //nolint:gosec -} - -// UseTargetClientCompressors returns a list of enabled compressors (from "zstd", "zlib", "snappy") -// for the target MongoDB client connection, as specified by the comma-separated environment -// variable PCSM_DEV_TARGET_CLIENT_COMPRESSORS. If unset or empty, returns nil. -func UseTargetClientCompressors() []string { - s := strings.TrimSpace(os.Getenv("PCSM_DEV_TARGET_CLIENT_COMPRESSORS")) - if s == "" { - return nil - } - - allowCompressors := []string{"zstd", "zlib", "snappy"} - - rv := make([]string, 0, min(len(s), len(allowCompressors))) - for a := range strings.SplitSeq(s, ",") { - a = strings.TrimSpace(a) - if slices.Contains(allowCompressors, a) && !slices.Contains(rv, a) { - rv = append(rv, a) - } - } - - return rv -} - -// OperationMongoDBCliTimeout returns the effective timeout for MongoDB client operations. -// If the environment variable `PCSM_MONGODB_CLI_OPERATION_TIMEOUT` is set, it must be a valid -// time duration string (e.g., "30s", "2m", "1h"). Otherwise, the -// DefaultMongoDBCliOperationTimeout is used. -func OperationMongoDBCliTimeout() time.Duration { - if v := strings.TrimSpace(os.Getenv("PCSM_MONGODB_CLI_OPERATION_TIMEOUT")); v != "" { - d, err := time.ParseDuration(v) - if err == nil && d > 0 { - return d - } - } - - return DefaultMongoDBCliOperationTimeout -} diff --git a/go.mod b/go.mod index 9e781f0..df4b585 100644 --- a/go.mod +++ b/go.mod @@ -4,37 +4,46 @@ go 1.25.0 require ( github.com/dustin/go-humanize v1.0.1 + github.com/go-viper/mapstructure/v2 v2.4.0 github.com/prometheus/client_golang v1.22.0 github.com/rs/zerolog v1.34.0 github.com/spf13/cobra v1.9.1 - github.com/spf13/pflag v1.0.6 - github.com/stretchr/testify v1.10.0 + github.com/spf13/viper v1.21.0 + github.com/stretchr/testify v1.11.1 go.mongodb.org/mongo-driver/v2 v2.2.1 - golang.org/x/sync v0.18.0 + golang.org/x/sync v0.19.0 ) require ( github.com/beorn7/perks v1.0.1 // indirect github.com/cespare/xxhash/v2 v2.3.0 // indirect github.com/davecgh/go-spew v1.1.1 // indirect + github.com/fsnotify/fsnotify v1.9.0 // indirect github.com/golang/snappy v1.0.0 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect github.com/klauspost/compress v1.18.0 // indirect - github.com/kr/text v0.2.0 // indirect github.com/mattn/go-colorable v0.1.13 // indirect github.com/mattn/go-isatty v0.0.19 // indirect github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect + github.com/pelletier/go-toml/v2 v2.2.4 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/prometheus/client_model v0.6.1 // indirect github.com/prometheus/common v0.62.0 // indirect github.com/prometheus/procfs v0.15.1 // indirect + github.com/sagikazarmark/locafero v0.11.0 // indirect + github.com/sourcegraph/conc v0.3.1-0.20240121214520-5f936abd7ae8 // indirect + github.com/spf13/afero v1.15.0 // indirect + github.com/spf13/cast v1.10.0 // indirect + github.com/spf13/pflag v1.0.10 // indirect + github.com/subosito/gotenv v1.6.0 // indirect github.com/xdg-go/pbkdf2 v1.0.0 // indirect github.com/xdg-go/scram v1.1.2 // indirect github.com/xdg-go/stringprep v1.0.4 // indirect github.com/youmark/pkcs8 v0.0.0-20240726163527-a2c0da244d78 // indirect - golang.org/x/crypto v0.45.0 // indirect - golang.org/x/sys v0.38.0 // indirect - golang.org/x/text v0.31.0 // indirect + go.yaml.in/yaml/v3 v3.0.4 // indirect + golang.org/x/crypto v0.46.0 // indirect + golang.org/x/sys v0.39.0 // indirect + golang.org/x/text v0.32.0 // indirect google.golang.org/protobuf v1.36.5 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index 42f8a11..e2b8ab4 100644 --- a/go.sum +++ b/go.sum @@ -4,11 +4,16 @@ github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UF github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= github.com/cpuguy83/go-md2man/v2 v2.0.6/go.mod h1:oOW0eioCTA6cOiMLiUPZOpcVxMig6NIQQ7OS05n1F4g= -github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY= github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto= +github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8= +github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0= +github.com/fsnotify/fsnotify v1.9.0 h1:2Ml+OJNzbYCTzsxtv8vKSFD9PbJjmhYF14k/jKC7S9k= +github.com/fsnotify/fsnotify v1.9.0/go.mod h1:8jBTzvmWwFyi3Pb8djgCCO5IBqzKJ/Jwo8TRcHyHii0= +github.com/go-viper/mapstructure/v2 v2.4.0 h1:EBsztssimR/CONLSZZ04E8qAkxNYq4Qp9LvH92wZUgs= +github.com/go-viper/mapstructure/v2 v2.4.0/go.mod h1:oJDH3BJKyqBA2TXFhDsKDGDTlndYOZ6rGS0BRZIxGhM= github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= github.com/golang/snappy v1.0.0 h1:Oy607GVXHs7RtbggtPBnr2RmDArIsAefDwvrdWvRhGs= github.com/golang/snappy v1.0.0/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= @@ -31,6 +36,8 @@ github.com/mattn/go-isatty v0.0.19 h1:JITubQf0MOLdlGRuRq+jtsDlekdYPia9ZFsB8h/APP github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq1c1nUAm88MOHcQC9l5mIlSMApZMrHA= github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ= +github.com/pelletier/go-toml/v2 v2.2.4 h1:mye9XuhQ6gvn5h28+VilKrrPoQVanw5PMw/TB0t5Ec4= +github.com/pelletier/go-toml/v2 v2.2.4/go.mod h1:2gIqNv+qfxSVS7cM2xJQKtLSTLUE9V8t9Stt+h56mCY= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= @@ -48,12 +55,25 @@ github.com/rs/xid v1.6.0/go.mod h1:7XoLgs4eV+QndskICGsho+ADou8ySMSjJKDIan90Nz0= github.com/rs/zerolog v1.34.0 h1:k43nTLIwcTVQAncfCw4KZ2VY6ukYoZaBPNOE8txlOeY= github.com/rs/zerolog v1.34.0/go.mod h1:bJsvje4Z08ROH4Nhs5iH600c3IkWhwp44iRc54W6wYQ= github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= +github.com/sagikazarmark/locafero v0.11.0 h1:1iurJgmM9G3PA/I+wWYIOw/5SyBtxapeHDcg+AAIFXc= +github.com/sagikazarmark/locafero v0.11.0/go.mod h1:nVIGvgyzw595SUSUE6tvCp3YYTeHs15MvlmU87WwIik= +github.com/sourcegraph/conc v0.3.1-0.20240121214520-5f936abd7ae8 h1:+jumHNA0Wrelhe64i8F6HNlS8pkoyMv5sreGx2Ry5Rw= +github.com/sourcegraph/conc v0.3.1-0.20240121214520-5f936abd7ae8/go.mod h1:3n1Cwaq1E1/1lhQhtRK2ts/ZwZEhjcQeJQ1RuC6Q/8U= +github.com/spf13/afero v1.15.0 h1:b/YBCLWAJdFWJTN9cLhiXXcD7mzKn9Dm86dNnfyQw1I= +github.com/spf13/afero v1.15.0/go.mod h1:NC2ByUVxtQs4b3sIUphxK0NioZnmxgyCrfzeuq8lxMg= +github.com/spf13/cast v1.10.0 h1:h2x0u2shc1QuLHfxi+cTJvs30+ZAHOGRic8uyGTDWxY= +github.com/spf13/cast v1.10.0/go.mod h1:jNfB8QC9IA6ZuY2ZjDp0KtFO2LZZlg4S/7bzP6qqeHo= github.com/spf13/cobra v1.9.1 h1:CXSaggrXdbHK9CF+8ywj8Amf7PBRmPCOJugH954Nnlo= github.com/spf13/cobra v1.9.1/go.mod h1:nDyEzZ8ogv936Cinf6g1RU9MRY64Ir93oCnqb9wxYW0= -github.com/spf13/pflag v1.0.6 h1:jFzHGLGAlb3ruxLB8MhbI6A8+AQX/2eW4qeyNZXNp2o= github.com/spf13/pflag v1.0.6/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= -github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= -github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +github.com/spf13/pflag v1.0.10 h1:4EBh2KAYBwaONj6b2Ye1GiHfwjqyROoF4RwYO+vPwFk= +github.com/spf13/pflag v1.0.10/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= +github.com/spf13/viper v1.21.0 h1:x5S+0EU27Lbphp4UKm1C+1oQO+rKx36vfCoaVebLFSU= +github.com/spf13/viper v1.21.0/go.mod h1:P0lhsswPGWD/1lZJ9ny3fYnVqxiegrlNrEmgLjbTCAY= +github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= +github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= +github.com/subosito/gotenv v1.6.0 h1:9NlTDc1FTs4qu0DDq7AEtTPNw6SVm7uBMsUCUjABIf8= +github.com/subosito/gotenv v1.6.0/go.mod h1:Dk4QP5c2W3ibzajGcXpNraDfq2IrhjMIvMSWPKKo0FU= github.com/xdg-go/pbkdf2 v1.0.0 h1:Su7DPu48wXMwC3bs7MCNG+z4FhcyEuz5dlvchbq0B0c= github.com/xdg-go/pbkdf2 v1.0.0/go.mod h1:jrpuAogTd400dnrH08LKmI/xc1MbPOebTwRqcT5RDeI= github.com/xdg-go/scram v1.1.2 h1:FHX5I5B4i4hKRVRBCFRxq1iQRej7WO3hhBuJf+UUySY= @@ -65,18 +85,20 @@ github.com/youmark/pkcs8 v0.0.0-20240726163527-a2c0da244d78/go.mod h1:aL8wCCfTfS github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= go.mongodb.org/mongo-driver/v2 v2.2.1 h1:w5xra3yyu/sGrziMzK1D0cRRaH/b7lWCSsoN6+WV6AM= go.mongodb.org/mongo-driver/v2 v2.2.1/go.mod h1:qQkDMhCGWl3FN509DfdPd4GRBLU/41zqF/k8eTRceps= +go.yaml.in/yaml/v3 v3.0.4 h1:tfq32ie2Jv2UxXFdLJdh3jXuOzWiL1fo0bu/FbuKpbc= +go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= -golang.org/x/crypto v0.45.0 h1:jMBrvKuj23MTlT0bQEOBcAE0mjg8mK9RXFhRH6nyF3Q= -golang.org/x/crypto v0.45.0/go.mod h1:XTGrrkGJve7CYK7J8PEww4aY7gM3qMCElcJQ8n8JdX4= +golang.org/x/crypto v0.46.0 h1:cKRW/pmt1pKAfetfu+RCEvjvZkA9RimPbh7bhFjGVBU= +golang.org/x/crypto v0.46.0/go.mod h1:Evb/oLKmMraqjZ2iQTwDwvCtJkczlDuTmdJXoZVzqU0= golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.18.0 h1:kr88TuHDroi+UVf+0hZnirlk8o8T+4MrK6mr60WkH/I= -golang.org/x/sync v0.18.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI= +golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4= +golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= @@ -85,16 +107,16 @@ golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.38.0 h1:3yZWxaJjBmCWXqhN1qh02AkOnCQ1poK6oF+a7xWL6Gc= -golang.org/x/sys v0.38.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= +golang.org/x/sys v0.39.0 h1:CvCKL8MeisomCi6qNZ+wbb0DN9E5AATixKsvNtMoMFk= +golang.org/x/sys v0.39.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= golang.org/x/text v0.3.8/go.mod h1:E6s5w1FMmriuDzIBO73fBruAKo1PCIq6d2Q6DHfQ8WQ= -golang.org/x/text v0.31.0 h1:aC8ghyu4JhP8VojJ2lEHBnochRno1sgL6nEi9WGFGMM= -golang.org/x/text v0.31.0/go.mod h1:tKRAlv61yKIjGGHX/4tP1LTbc13YSec1pxVEWXzfoeM= +golang.org/x/text v0.32.0 h1:ZD01bjUt1FQ9WJ0ClOL5vxgxOI/sVCNgX1YtKwcY0mU= +golang.org/x/text v0.32.0/go.mod h1:o/rUWzghvpD5TXrTIBuJU77MTaN0ljMWE47kxGJQ7jY= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= diff --git a/main.go b/main.go index 13d053c..47e4722 100644 --- a/main.go +++ b/main.go @@ -10,7 +10,6 @@ import ( "os" "os/signal" "runtime" - "strconv" "strings" "time" @@ -19,7 +18,6 @@ import ( "github.com/prometheus/client_golang/prometheus/promhttp" "github.com/rs/zerolog" "github.com/spf13/cobra" - "github.com/spf13/pflag" "go.mongodb.org/mongo-driver/v2/mongo" "go.mongodb.org/mongo-driver/v2/x/mongo/driver/connstring" @@ -34,7 +32,6 @@ import ( // Constants for server configuration. const ( - DefaultServerPort = 2242 ServerReadTimeout = 30 * time.Second ServerReadHeaderTimeout = 3 * time.Second MaxRequestSize = humanize.MiByte @@ -53,366 +50,358 @@ func buildVersion() string { return Version + " " + GitCommit + " " + BuildTime } -//nolint:gochecknoglobals -var rootCmd = &cobra.Command{ - Use: "pcsm", - Short: "Percona ClusterSync for MongoDB replication tool", +func main() { + rootCmd := newRootCmd() - SilenceUsage: true, + err := rootCmd.Execute() + if err != nil { + zerolog.Ctx(context.Background()).Fatal().Err(err).Msg("") + } +} - PersistentPreRun: func(cmd *cobra.Command, _ []string) { - logLevelFlag, _ := cmd.PersistentFlags().GetString("log-level") - logJSON, _ := cmd.PersistentFlags().GetBool("log-json") - logNoColor, _ := cmd.PersistentFlags().GetBool("no-color") +func newRootCmd() *cobra.Command { + cfg := &config.Config{} - logLevel, err := zerolog.ParseLevel(logLevelFlag) - if err != nil { - log.InitGlobals(0, logJSON, true).Fatal().Msg("Unknown log level") - } + rootCmd := &cobra.Command{ + Use: "pcsm", + Short: "Percona ClusterSync for MongoDB replication tool", - lg := log.InitGlobals(logLevel, logJSON, logNoColor) - ctx := lg.WithContext(context.Background()) - cmd.SetContext(ctx) - }, + SilenceUsage: true, - RunE: func(cmd *cobra.Command, _ []string) error { - // Check if this is the root command being executed without a subcommand - if cmd.CalledAs() != "pcsm" || cmd.ArgsLenAtDash() != -1 { - return nil - } + PersistentPreRunE: func(cmd *cobra.Command, _ []string) error { + err := config.Load(cmd, cfg) + if err != nil { + return errors.Wrap(err, "load config") + } - port, err := getPort(cmd.Flags()) - if err != nil { - return err - } + logLevel, err := zerolog.ParseLevel(cfg.Log.Level) + if err != nil { + logLevel = zerolog.InfoLevel + } - sourceURI, _ := cmd.Flags().GetString("source") - if sourceURI == "" { - sourceURI = os.Getenv("PCSM_SOURCE_URI") - } - if sourceURI == "" { - return errors.New("required flag --source not set") - } + lg := log.InitGlobals(logLevel, cfg.Log.JSON, cfg.Log.NoColor) + ctx := lg.WithContext(context.Background()) + cmd.SetContext(ctx) - targetURI, _ := cmd.Flags().GetString("target") - if targetURI == "" { - targetURI = os.Getenv("PCSM_TARGET_URI") - } - if targetURI == "" { - return errors.New("required flag --target not set") - } + config.WarnDeprecatedEnvVars(ctx) - if ok, _ := cmd.Flags().GetBool("reset-state"); ok { - err := resetState(cmd.Context(), targetURI) - if err != nil { - return err - } + return nil + }, - log.New("cli").Info("State has been reset") - } + RunE: func(cmd *cobra.Command, _ []string) error { + // Check if this is the root command being executed without a subcommand + if cmd.CalledAs() != "pcsm" || cmd.ArgsLenAtDash() != -1 { + return nil + } - start, _ := cmd.Flags().GetBool("start") - pause, _ := cmd.Flags().GetBool("pause-on-initial-sync") + err := config.Validate(cfg) + if err != nil { + return errors.Wrap(err, "validate config") + } - log.Ctx(cmd.Context()).Info("Percona ClusterSync for MongoDB " + buildVersion()) + if cfg.ResetState { + err := resetState(cmd.Context(), cfg) + if err != nil { + return err + } - return runServer(cmd.Context(), serverOptions{ - port: port, - sourceURI: sourceURI, - targetURI: targetURI, - start: start, - pause: pause, - }) - }, -} + log.New("cli").Info("State has been reset") + } -//nolint:gochecknoglobals -var versionCmd = &cobra.Command{ - Use: "version", - Short: "Print the version", - Run: func(cmd *cobra.Command, _ []string) { - info := fmt.Sprintf("Version: %s\nPlatform: %s\nGitCommit: "+ - "%s\nGitBranch: %s\nBuildTime: %s\nGoVersion: %s", - Version, - Platform, - GitCommit, - GitBranch, - BuildTime, - runtime.Version(), - ) - - cmd.Println(info) - }, -} + log.Ctx(cmd.Context()).Info("Percona ClusterSync for MongoDB " + buildVersion()) -//nolint:gochecknoglobals -var statusCmd = &cobra.Command{ - Use: "status", - Short: "Get the status of the replication process", - RunE: func(cmd *cobra.Command, _ []string) error { - port, err := getPort(cmd.Flags()) - if err != nil { - return err - } + return runServer(cmd.Context(), cfg) + }, + } - return NewClient(port).Status(cmd.Context()) - }, -} + // Persistent flags (available to all subcommands) + rootCmd.PersistentFlags().String("log-level", "info", "Log level") + rootCmd.PersistentFlags().Bool("log-json", false, "Output log in JSON format") + rootCmd.PersistentFlags().Bool("log-no-color", false, "Disable log color") -//nolint:gochecknoglobals -var startCmd = &cobra.Command{ - Use: "start", - Short: "Start Cluster Replication", - RunE: func(cmd *cobra.Command, _ []string) error { - port, err := getPort(cmd.Flags()) - if err != nil { - return err - } + rootCmd.PersistentFlags().Bool("no-color", false, "") + rootCmd.PersistentFlags().MarkDeprecated("no-color", "use --log-no-color instead") //nolint:errcheck - pauseOnInitialSync, _ := cmd.Flags().GetBool("pause-on-initial-sync") - includeNamespaces, _ := cmd.Flags().GetStringSlice("include-namespaces") - excludeNamespaces, _ := cmd.Flags().GetStringSlice("exclude-namespaces") + rootCmd.PersistentFlags().Int("port", config.DefaultServerPort, "Port number") - startOptions := startRequest{ - PauseOnInitialSync: pauseOnInitialSync, - IncludeNamespaces: includeNamespaces, - ExcludeNamespaces: excludeNamespaces, - } + // MongoDB client timeout (visible: commonly needed for debugging) + rootCmd.PersistentFlags().String("mongodb-operation-timeout", config.DefaultMongoDBOperationTimeout.String(), + "Timeout for MongoDB operations (e.g., 30s, 5m)") - return NewClient(port).Start(cmd.Context(), startOptions) - }, -} + // Root command specific flags + rootCmd.Flags().String("source", "", "MongoDB connection string for the source") + rootCmd.Flags().String("target", "", "MongoDB connection string for the target") + rootCmd.Flags().Bool("start", false, "") + rootCmd.Flags().MarkHidden("start") //nolint:errcheck -//nolint:gochecknoglobals -var finalizeCmd = &cobra.Command{ - Use: "finalize", - Short: "Finalize Cluster Replication", - RunE: func(cmd *cobra.Command, _ []string) error { - port, err := getPort(cmd.Flags()) - if err != nil { - return err - } + rootCmd.Flags().Bool("reset-state", false, "") + rootCmd.Flags().MarkHidden("reset-state") //nolint:errcheck - ignoreHistoryLost, _ := cmd.Flags().GetBool("ignore-history-lost") + rootCmd.Flags().Bool("pause-on-initial-sync", false, "") + rootCmd.Flags().MarkHidden("pause-on-initial-sync") //nolint:errcheck - finalizeOptions := finalizeRequest{ - IgnoreHistoryLost: ignoreHistoryLost, - } + rootCmd.AddCommand( + newVersionCmd(), + newStatusCmd(cfg), + newStartCmd(cfg), + newFinalizeCmd(cfg), + newPauseCmd(cfg), + newResumeCmd(cfg), + newResetCmd(cfg), + ) - return NewClient(port).Finalize(cmd.Context(), finalizeOptions) - }, + return rootCmd } -//nolint:gochecknoglobals -var pauseCmd = &cobra.Command{ - Use: "pause", - Short: "Pause Cluster Replication", - RunE: func(cmd *cobra.Command, _ []string) error { - port, err := getPort(cmd.Flags()) - if err != nil { - return err - } +func newVersionCmd() *cobra.Command { + return &cobra.Command{ + Use: "version", + Short: "Print the version", + Run: func(cmd *cobra.Command, _ []string) { + info := fmt.Sprintf("Version: %s\nPlatform: %s\nGitCommit: "+ + "%s\nGitBranch: %s\nBuildTime: %s\nGoVersion: %s", + Version, + Platform, + GitCommit, + GitBranch, + BuildTime, + runtime.Version(), + ) + + cmd.Println(info) + }, + } +} - return NewClient(port).Pause(cmd.Context()) - }, +func newStatusCmd(cfg *config.Config) *cobra.Command { + return &cobra.Command{ + Use: "status", + Short: "Get the status of the replication process", + RunE: func(cmd *cobra.Command, _ []string) error { + return NewClient(cfg.Port).Status(cmd.Context()) + }, + } } -//nolint:gochecknoglobals -var resumeCmd = &cobra.Command{ - Use: "resume", - Short: "Resume Cluster Replication", - RunE: func(cmd *cobra.Command, _ []string) error { - port, err := getPort(cmd.Flags()) - if err != nil { - return err - } +func newStartCmd(cfg *config.Config) *cobra.Command { + cmd := &cobra.Command{ + Use: "start", + Short: "Start Cluster Replication", + RunE: func(cmd *cobra.Command, _ []string) error { + pauseOnInitialSync, _ := cmd.Flags().GetBool("pause-on-initial-sync") + includeNamespaces, _ := cmd.Flags().GetStringSlice("include-namespaces") + excludeNamespaces, _ := cmd.Flags().GetStringSlice("exclude-namespaces") + + startOptions := startRequest{ + PauseOnInitialSync: pauseOnInitialSync, + IncludeNamespaces: includeNamespaces, + ExcludeNamespaces: excludeNamespaces, + } - fromFailure, _ := cmd.Flags().GetBool("from-failure") + if cfg.Clone.NumParallelCollections != 0 { + v := cfg.Clone.NumParallelCollections + startOptions.CloneNumParallelCollections = &v + } + if cfg.Clone.NumReadWorkers != 0 { + v := cfg.Clone.NumReadWorkers + startOptions.CloneNumReadWorkers = &v + } + if cfg.Clone.NumInsertWorkers != 0 { + v := cfg.Clone.NumInsertWorkers + startOptions.CloneNumInsertWorkers = &v + } + if cfg.Clone.SegmentSize != "" { + v := cfg.Clone.SegmentSize + startOptions.CloneSegmentSize = &v + } + if cfg.Clone.ReadBatchSize != "" { + v := cfg.Clone.ReadBatchSize + startOptions.CloneReadBatchSize = &v + } - resumeOptions := resumeRequest{ - FromFailure: fromFailure, - } + return NewClient(cfg.Port).Start(cmd.Context(), startOptions) + }, + } - return NewClient(port).Resume(cmd.Context(), resumeOptions) - }, -} + cmd.Flags().Bool("pause-on-initial-sync", false, "") + cmd.Flags().MarkHidden("pause-on-initial-sync") //nolint:errcheck -//nolint:gochecknoglobals -var resetCmd = &cobra.Command{ - Use: "reset", - Short: "Reset PCSM state (heartbeat and recovery data)", - RunE: func(cmd *cobra.Command, _ []string) error { - targetURI, _ := cmd.Flags().GetString("target") - if targetURI == "" { - targetURI = os.Getenv("PCSM_TARGET_URI") - } - if targetURI == "" { - return errors.New("required flag --target not set") - } + cmd.Flags().StringSlice("include-namespaces", nil, + "Namespaces to include in the replication (e.g. db1.collection1,db2.collection2)") + cmd.Flags().StringSlice("exclude-namespaces", nil, + "Namespaces to exclude from the replication (e.g. db3.collection3,db4.*)") - err := resetState(cmd.Context(), targetURI) - if err != nil { - return err - } + cmd.Flags().Int("clone-num-parallel-collections", 0, + "Number of collections to clone in parallel (0 = auto)") + cmd.Flags().Int("clone-num-read-workers", 0, + "Number of read workers during clone (0 = auto)") + cmd.Flags().Int("clone-num-insert-workers", 0, + "Number of insert workers during clone (0 = auto)") + cmd.Flags().String("clone-segment-size", "", "") + cmd.Flags().MarkHidden("clone-segment-size") //nolint:errcheck - log.New("cli").Info("OK: reset all") + cmd.Flags().String("clone-read-batch-size", "", "") + cmd.Flags().MarkHidden("clone-read-batch-size") //nolint:errcheck - return nil - }, + return cmd } -//nolint:gochecknoglobals -var resetRecoveryCmd = &cobra.Command{ - Use: "recovery", - Hidden: true, - Short: "Reset recovery state", - RunE: func(cmd *cobra.Command, _ []string) error { - targetURI, _ := cmd.InheritedFlags().GetString("target") - if targetURI == "" { - targetURI = os.Getenv("PCSM_TARGET_URI") - } - if targetURI == "" { - return errors.New("required flag --target not set") - } - - ctx := cmd.Context() - - target, err := topo.Connect(ctx, targetURI) - if err != nil { - return errors.Wrap(err, "connect") - } +func newFinalizeCmd(cfg *config.Config) *cobra.Command { + cmd := &cobra.Command{ + Use: "finalize", + Short: "Finalize Cluster Replication", + RunE: func(cmd *cobra.Command, _ []string) error { + ignoreHistoryLost, _ := cmd.Flags().GetBool("ignore-history-lost") - defer func() { - err := util.CtxWithTimeout(ctx, config.DisconnectTimeout, target.Disconnect) - if err != nil { - log.Ctx(ctx).Warn("Disconnect: " + err.Error()) + finalizeOptions := finalizeRequest{ + IgnoreHistoryLost: ignoreHistoryLost, } - }() - err = DeleteRecoveryData(ctx, target) - if err != nil { - return err - } + return NewClient(cfg.Port).Finalize(cmd.Context(), finalizeOptions) + }, + } - log.New("cli").Info("OK: reset recovery") + cmd.Flags().Bool("ignore-history-lost", false, "") + cmd.Flags().MarkHidden("ignore-history-lost") //nolint:errcheck - return nil - }, + return cmd } -//nolint:gochecknoglobals -var resetHeartbeatCmd = &cobra.Command{ - Use: "heartbeat", - Hidden: true, - Short: "Reset heartbeat state", - RunE: func(cmd *cobra.Command, _ []string) error { - targetURI, _ := cmd.InheritedFlags().GetString("target") - if targetURI == "" { - targetURI = os.Getenv("PCSM_TARGET_URI") - } - if targetURI == "" { - return errors.New("required flag --target not set") - } +func newPauseCmd(cfg *config.Config) *cobra.Command { + return &cobra.Command{ + Use: "pause", + Short: "Pause Cluster Replication", + RunE: func(cmd *cobra.Command, _ []string) error { + return NewClient(cfg.Port).Pause(cmd.Context()) + }, + } +} - ctx := cmd.Context() +func newResumeCmd(cfg *config.Config) *cobra.Command { + cmd := &cobra.Command{ + Use: "resume", + Short: "Resume Cluster Replication", + RunE: func(cmd *cobra.Command, _ []string) error { + fromFailure, _ := cmd.Flags().GetBool("from-failure") - target, err := topo.Connect(ctx, targetURI) - if err != nil { - return errors.Wrap(err, "connect") - } + resumeOptions := resumeRequest{ + FromFailure: fromFailure, + } + + return NewClient(cfg.Port).Resume(cmd.Context(), resumeOptions) + }, + } + + cmd.Flags().Bool("from-failure", false, "Resume from failure") + + return cmd +} - defer func() { - err := util.CtxWithTimeout(ctx, config.DisconnectTimeout, target.Disconnect) +func newResetCmd(cfg *config.Config) *cobra.Command { + cmd := &cobra.Command{ + Use: "reset", + Short: "Reset PCSM state (heartbeat and recovery data)", + // Reset command has an override for the --target flag + PersistentPreRunE: func(cmd *cobra.Command, args []string) error { + err := cmd.Root().PersistentPreRunE(cmd, args) if err != nil { - log.Ctx(ctx).Warn("Disconnect: " + err.Error()) + return errors.Wrap(err, "root pre-run") } - }() - err = DeleteHeartbeat(ctx, target) - if err != nil { - return err - } + if cfg.Target == "" { + return errors.New("required flag --target not set") + } - log.New("cli").Info("OK: reset heartbeat") + return nil + }, + RunE: func(cmd *cobra.Command, _ []string) error { + err := resetState(cmd.Context(), cfg) + if err != nil { + return err + } - return nil - }, -} + log.New("cli").Info("OK: reset all") -func getPort(flags *pflag.FlagSet) (int, error) { - port, _ := flags.GetInt("port") - if flags.Changed("port") { - return port, nil + return nil + }, } - portVar := os.Getenv("PCSM_PORT") - if portVar == "" { - return port, nil - } + cmd.PersistentFlags().String("target", "", "MongoDB connection string for the target") - parsedPort, err := strconv.ParseInt(portVar, 10, 32) - if err != nil { - return 0, errors.Errorf("invalid environment variable PCSM_PORT='%s'", portVar) - } + cmd.AddCommand( + newResetRecoveryCmd(cfg), + newResetHeartbeatCmd(cfg), + ) - return int(parsedPort), nil + return cmd } -func main() { - rootCmd.PersistentFlags().String("log-level", "info", "Log level") - rootCmd.PersistentFlags().Bool("log-json", false, "Output log in JSON format") - rootCmd.PersistentFlags().Bool("no-color", false, "Disable log color") +func newResetRecoveryCmd(cfg *config.Config) *cobra.Command { + return &cobra.Command{ + Use: "recovery", + Hidden: true, + Short: "Reset recovery state", + RunE: func(cmd *cobra.Command, _ []string) error { + ctx := cmd.Context() - rootCmd.Flags().Int("port", DefaultServerPort, "Port number") - rootCmd.Flags().String("source", "", "MongoDB connection string for the source") - rootCmd.Flags().String("target", "", "MongoDB connection string for the target") - rootCmd.Flags().Bool("start", false, "Start Cluster Replication immediately") - rootCmd.Flags().Bool("reset-state", false, "Reset stored PCSM state") - rootCmd.Flags().Bool("pause-on-initial-sync", false, "Pause on Initial Sync") - rootCmd.Flags().MarkHidden("start") //nolint:errcheck - rootCmd.Flags().MarkHidden("reset-state") //nolint:errcheck - rootCmd.Flags().MarkHidden("pause-on-initial-sync") //nolint:errcheck + target, err := topo.Connect(ctx, cfg.Target, cfg) + if err != nil { + return errors.Wrap(err, "connect") + } - statusCmd.Flags().Int("port", DefaultServerPort, "Port number") + defer func() { + err := util.CtxWithTimeout(ctx, config.DisconnectTimeout, target.Disconnect) + if err != nil { + log.Ctx(ctx).Warn("Disconnect: " + err.Error()) + } + }() - startCmd.Flags().Int("port", DefaultServerPort, "Port number") - startCmd.Flags().Bool("pause-on-initial-sync", false, "Pause on Initial Sync") - startCmd.Flags().MarkHidden("pause-on-initial-sync") //nolint:errcheck - startCmd.Flags().StringSlice("include-namespaces", nil, - "Namespaces to include in the replication (e.g. db1.collection1,db2.collection2)") - startCmd.Flags().StringSlice("exclude-namespaces", nil, - "Namespaces to exclude from the replication (e.g. db3.collection3,db4.*)") + err = DeleteRecoveryData(ctx, target) + if err != nil { + return err + } - pauseCmd.Flags().Int("port", DefaultServerPort, "Port number") + log.New("cli").Info("OK: reset recovery") - resumeCmd.Flags().Int("port", DefaultServerPort, "Port number") - resumeCmd.Flags().Bool("from-failure", false, "Reuse from failure") + return nil + }, + } +} - finalizeCmd.Flags().Int("port", DefaultServerPort, "Port number") - finalizeCmd.Flags().Bool("ignore-history-lost", false, "Ignore history lost error") - finalizeCmd.Flags().MarkHidden("ignore-history-lost") //nolint:errcheck +func newResetHeartbeatCmd(cfg *config.Config) *cobra.Command { + return &cobra.Command{ + Use: "heartbeat", + Hidden: true, + Short: "Reset heartbeat state", + RunE: func(cmd *cobra.Command, _ []string) error { + ctx := cmd.Context() - resetCmd.Flags().String("target", "", "MongoDB connection string for the target") + target, err := topo.Connect(ctx, cfg.Target, cfg) + if err != nil { + return errors.Wrap(err, "connect") + } - resetCmd.AddCommand(resetRecoveryCmd, resetHeartbeatCmd) - rootCmd.AddCommand( - versionCmd, - statusCmd, - startCmd, - finalizeCmd, - pauseCmd, - resumeCmd, - resetCmd, - ) + defer func() { + err := util.CtxWithTimeout(ctx, config.DisconnectTimeout, target.Disconnect) + if err != nil { + log.Ctx(ctx).Warn("Disconnect: " + err.Error()) + } + }() - err := rootCmd.Execute() - if err != nil { - zerolog.Ctx(context.Background()).Fatal().Err(err).Msg("") + err = DeleteHeartbeat(ctx, target) + if err != nil { + return err + } + + log.New("cli").Info("OK: reset heartbeat") + + return nil + }, } } -func resetState(ctx context.Context, targetURI string) error { - target, err := topo.Connect(ctx, targetURI) +func resetState(ctx context.Context, cfg *config.Config) error { + target, err := topo.Connect(ctx, cfg.Target, cfg) if err != nil { return errors.Wrap(err, "connect") } @@ -437,51 +426,19 @@ func resetState(ctx context.Context, targetURI string) error { return nil } -type serverOptions struct { - port int - sourceURI string - targetURI string - start bool - pause bool -} - -func (s serverOptions) validate() error { - if s.port <= 1024 || s.port > 65535 { - return errors.New("port value is outside the supported range [1024 - 65535]") - } - - switch { - case s.sourceURI == "" && s.targetURI == "": - return errors.New("source URI and target URI are empty") - case s.sourceURI == "": - return errors.New("source URI is empty") - case s.targetURI == "": - return errors.New("target URI is empty") - case s.sourceURI == s.targetURI: - return errors.New("source URI and target URI are identical") - } - - return nil -} - // runServer starts the HTTP server with the provided configuration. -func runServer(ctx context.Context, options serverOptions) error { - err := options.validate() - if err != nil { - return errors.Wrap(err, "validate options") - } - +func runServer(_ context.Context, cfg *config.Config) error { ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, os.Kill) defer stop() - srv, err := createServer(ctx, options.sourceURI, options.targetURI) + srv, err := createServer(ctx, cfg) if err != nil { return errors.Wrap(err, "new server") } - if options.start && srv.pcsm.Status(ctx).State == pcsm.StateIdle { + if cfg.Start && srv.pcsm.Status(ctx).State == pcsm.StateIdle { err = srv.pcsm.Start(ctx, &pcsm.StartOptions{ - PauseOnInitialSync: options.pause, + PauseOnInitialSync: cfg.PauseOnInitialSync, }) if err != nil { log.New("cli").Error(err, "Failed to start Cluster Replication") @@ -499,7 +456,12 @@ func runServer(ctx context.Context, options serverOptions) error { os.Exit(0) }() - addr := fmt.Sprintf("localhost:%d", options.port) + port := cfg.Port + if port == 0 { + port = config.DefaultServerPort + } + + addr := fmt.Sprintf("localhost:%d", port) httpServer := http.Server{ Addr: addr, Handler: srv.Handler(), @@ -513,8 +475,9 @@ func runServer(ctx context.Context, options serverOptions) error { return httpServer.ListenAndServe() //nolint:wrapcheck } -// server represents the replication server. type server struct { + // cfg holds the configuration. + cfg *config.Config // sourceCluster is the MongoDB client for the source cluster. sourceCluster *mongo.Client // targetCluster is the MongoDB client for the target cluster. @@ -529,10 +492,10 @@ type server struct { } // createServer creates a new server with the given options. -func createServer(ctx context.Context, sourceURI, targetURI string) (*server, error) { +func createServer(ctx context.Context, cfg *config.Config) (*server, error) { lg := log.Ctx(ctx) - source, err := topo.Connect(ctx, sourceURI) + source, err := topo.Connect(ctx, cfg.Source, cfg) if err != nil { return nil, errors.Wrap(err, "connect to source cluster") } @@ -553,13 +516,11 @@ func createServer(ctx context.Context, sourceURI, targetURI string) (*server, er return nil, errors.Wrap(err, "source version") } - cs, _ := connstring.Parse(sourceURI) + cs, _ := connstring.Parse(cfg.Source) lg.Infof("Connected to source cluster [%s]: %s://%s", sourceVersion.FullString(), cs.Scheme, strings.Join(cs.Hosts, ",")) - target, err := topo.ConnectWithOptions(ctx, targetURI, &topo.ConnectOptions{ - Compressors: config.UseTargetClientCompressors(), - }) + target, err := topo.Connect(ctx, cfg.Target, cfg) if err != nil { return nil, errors.Wrap(err, "connect to target cluster") } @@ -580,7 +541,7 @@ func createServer(ctx context.Context, sourceURI, targetURI string) (*server, er return nil, errors.Wrap(err, "target version") } - cs, _ = connstring.Parse(targetURI) + cs, _ = connstring.Parse(cfg.Target) lg.Infof("Connected to target cluster [%s]: %s://%s", targetVersion.FullString(), cs.Scheme, strings.Join(cs.Hosts, ",")) @@ -611,6 +572,7 @@ func createServer(ctx context.Context, sourceURI, targetURI string) (*server, er go RunCheckpointing(ctx, target, pcs) s := &server{ + cfg: cfg, sourceCluster: source, targetCluster: target, pcsm: pcs, @@ -634,12 +596,12 @@ func (s *server) Close(ctx context.Context) error { func (s *server) Handler() http.Handler { mux := http.NewServeMux() - mux.HandleFunc("/status", s.handleStatus) - mux.HandleFunc("/start", s.handleStart) - mux.HandleFunc("/finalize", s.handleFinalize) - mux.HandleFunc("/pause", s.handlePause) - mux.HandleFunc("/resume", s.handleResume) - mux.Handle("/metrics", s.handleMetrics()) + mux.HandleFunc("/status", s.HandleStatus) + mux.HandleFunc("/start", s.HandleStart) + mux.HandleFunc("/finalize", s.HandleFinalize) + mux.HandleFunc("/pause", s.HandlePause) + mux.HandleFunc("/resume", s.HandleResume) + mux.Handle("/metrics", s.HandleMetrics()) return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.URL.Path == "/metrics" { @@ -651,8 +613,8 @@ func (s *server) Handler() http.Handler { }) } -// handleStatus handles the /status endpoint. -func (s *server) handleStatus(w http.ResponseWriter, r *http.Request) { +// HandleStatus handles the /status endpoint. +func (s *server) HandleStatus(w http.ResponseWriter, r *http.Request) { ctx, cancel := context.WithTimeout(r.Context(), ServerResponseTimeout) defer cancel() @@ -735,8 +697,66 @@ func (s *server) handleStatus(w http.ResponseWriter, r *http.Request) { writeResponse(w, res) } -// handleStart handles the /start endpoint. -func (s *server) handleStart(w http.ResponseWriter, r *http.Request) { +// resolveStartOptions resolves the start options from the HTTP request and config. +// Clone tuning options use config (env var) as defaults, CLI/HTTP params override. +func resolveStartOptions(cfg *config.Config, params startRequest) (*pcsm.StartOptions, error) { + options := &pcsm.StartOptions{ + PauseOnInitialSync: params.PauseOnInitialSync, + IncludeNamespaces: params.IncludeNamespaces, + ExcludeNamespaces: params.ExcludeNamespaces, + Repl: pcsm.ReplOptions{ + UseCollectionBulkWrite: cfg.UseCollectionBulkWrite, + }, + Clone: pcsm.CloneOptions{ + Parallelism: cfg.Clone.NumParallelCollections, + ReadWorkers: cfg.Clone.NumReadWorkers, + InsertWorkers: cfg.Clone.NumInsertWorkers, + }, + } + + if params.CloneNumParallelCollections != nil { + options.Clone.Parallelism = *params.CloneNumParallelCollections + } + + if params.CloneNumReadWorkers != nil { + options.Clone.ReadWorkers = *params.CloneNumReadWorkers + } + + if params.CloneNumInsertWorkers != nil { + options.Clone.InsertWorkers = *params.CloneNumInsertWorkers + } + + segmentSizeStr := cfg.Clone.SegmentSize + if params.CloneSegmentSize != nil { + segmentSizeStr = *params.CloneSegmentSize + } + + if segmentSizeStr != "" { + segmentSize, err := config.ParseAndValidateCloneSegmentSize(segmentSizeStr) + if err != nil { + return nil, errors.Wrap(err, "invalid clone segment size") + } + options.Clone.SegmentSizeBytes = segmentSize + } + + readBatchSizeStr := cfg.Clone.ReadBatchSize + if params.CloneReadBatchSize != nil { + readBatchSizeStr = *params.CloneReadBatchSize + } + + if readBatchSizeStr != "" { + batchSize, err := config.ParseAndValidateCloneReadBatchSize(readBatchSizeStr) + if err != nil { + return nil, errors.Wrap(err, "invalid clone read batch size") + } + options.Clone.ReadBatchSizeBytes = batchSize + } + + return options, nil +} + +// HandleStart handles the /start endpoint. +func (s *server) HandleStart(w http.ResponseWriter, r *http.Request) { ctx, cancel := context.WithTimeout(r.Context(), ServerResponseTimeout) defer cancel() @@ -778,13 +798,14 @@ func (s *server) handleStart(w http.ResponseWriter, r *http.Request) { } } - options := &pcsm.StartOptions{ - PauseOnInitialSync: params.PauseOnInitialSync, - IncludeNamespaces: params.IncludeNamespaces, - ExcludeNamespaces: params.ExcludeNamespaces, + options, err := resolveStartOptions(s.cfg, params) + if err != nil { + writeResponse(w, startResponse{Err: err.Error()}) + + return } - err := s.pcsm.Start(ctx, options) + err = s.pcsm.Start(ctx, options) if err != nil { writeResponse(w, startResponse{Err: err.Error()}) @@ -794,8 +815,8 @@ func (s *server) handleStart(w http.ResponseWriter, r *http.Request) { writeResponse(w, startResponse{Ok: true}) } -// handleFinalize handles the /finalize endpoint. -func (s *server) handleFinalize(w http.ResponseWriter, r *http.Request) { +// HandleFinalize handles the /finalize endpoint. +func (s *server) HandleFinalize(w http.ResponseWriter, r *http.Request) { ctx, cancel := context.WithTimeout(r.Context(), ServerResponseTimeout) defer cancel() @@ -851,8 +872,8 @@ func (s *server) handleFinalize(w http.ResponseWriter, r *http.Request) { writeResponse(w, finalizeResponse{Ok: true}) } -// handlePause handles the /pause endpoint. -func (s *server) handlePause(w http.ResponseWriter, r *http.Request) { +// HandlePause handles the /pause endpoint. +func (s *server) HandlePause(w http.ResponseWriter, r *http.Request) { ctx, cancel := context.WithTimeout(r.Context(), ServerResponseTimeout) defer cancel() @@ -882,8 +903,8 @@ func (s *server) handlePause(w http.ResponseWriter, r *http.Request) { writeResponse(w, pauseResponse{Ok: true}) } -// handleResume handles the /resume endpoint. -func (s *server) handleResume(w http.ResponseWriter, r *http.Request) { +// HandleResume handles the /resume endpoint. +func (s *server) HandleResume(w http.ResponseWriter, r *http.Request) { ctx, cancel := context.WithTimeout(r.Context(), ServerResponseTimeout) defer cancel() @@ -939,7 +960,7 @@ func (s *server) handleResume(w http.ResponseWriter, r *http.Request) { writeResponse(w, resumeResponse{Ok: true}) } -func (s *server) handleMetrics() http.Handler { +func (s *server) HandleMetrics() http.Handler { return promhttp.HandlerFor(s.promRegistry, promhttp.HandlerOpts{}) } @@ -962,6 +983,20 @@ type startRequest struct { IncludeNamespaces []string `json:"includeNamespaces,omitempty"` // ExcludeNamespaces are the namespaces to exclude from the replication. ExcludeNamespaces []string `json:"excludeNamespaces,omitempty"` + + // Clone tuning options (pointer types to distinguish "not set" from zero value) + // CloneNumParallelCollections is the number of collections to clone in parallel. + CloneNumParallelCollections *int `json:"cloneNumParallelCollections,omitempty"` + // CloneNumReadWorkers is the number of read workers during clone. + CloneNumReadWorkers *int `json:"cloneNumReadWorkers,omitempty"` + // CloneNumInsertWorkers is the number of insert workers during clone. + CloneNumInsertWorkers *int `json:"cloneNumInsertWorkers,omitempty"` + // CloneSegmentSize is the segment size for clone operations (e.g., "100MB", "1GiB"). + CloneSegmentSize *string `json:"cloneSegmentSize,omitempty"` + // CloneReadBatchSize is the read batch size during clone (e.g., "16MiB"). + CloneReadBatchSize *string `json:"cloneReadBatchSize,omitempty"` + + // NOTE: UseCollectionBulkWrite intentionally NOT exposed via HTTP (internal only) } // startResponse represents the response body for the /start endpoint. diff --git a/main_test.go b/main_test.go new file mode 100644 index 0000000..2af82ab --- /dev/null +++ b/main_test.go @@ -0,0 +1,613 @@ +package main_test + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "net/http/httptest" + "os" + "os/exec" + "path/filepath" + "strings" + "sync" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/percona/percona-clustersync-mongodb/errors" +) + +// errCommandTimeout is returned when the command execution times out. +var errCommandTimeout = errors.New("command timed out") + +// binaryPath holds the path to the compiled pcsm binary. +// +//nolint:gochecknoglobals +var binaryPath string + +// TestMain builds the binary once before running all tests. +func TestMain(m *testing.M) { + code := runTestMain(m) + os.Exit(code) +} + +func runTestMain(m *testing.M) int { + tmpDir, err := os.MkdirTemp("", "pcsm-integration-test") + if err != nil { + fmt.Fprintf(os.Stderr, "failed to create temp dir: %v\n", err) + + return 1 + } + defer os.RemoveAll(tmpDir) + + binaryPath = filepath.Join(tmpDir, "pcsm") + + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute) + defer cancel() + + cmd := exec.CommandContext(ctx, "go", "build", "-race", "-o", binaryPath, ".") + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr + + err = cmd.Run() + if err != nil { + fmt.Fprintf(os.Stderr, "failed to build binary: %v\n", err) + + return 1 + } + + // Run tests + return m.Run() +} + +// capturedRequest holds the details of an HTTP request captured by the mock server. +type capturedRequest struct { + Method string + Path string + Body []byte +} + +// mockPCSMServer creates a mock PCSM HTTP server that captures requests. +// It returns the server and a channel that receives captured requests. +func mockPCSMServer(t *testing.T, response any) (*httptest.Server, *capturedRequest, *sync.Mutex) { + t.Helper() + + var captured capturedRequest + var mu sync.Mutex + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + mu.Lock() + defer mu.Unlock() + + captured.Method = r.Method + captured.Path = r.URL.Path + + body, err := io.ReadAll(r.Body) + if err != nil { + t.Errorf("failed to read request body: %v", err) + http.Error(w, "internal error", http.StatusInternalServerError) + + return + } + captured.Body = body + + w.Header().Set("Content-Type", "application/json") + + encErr := json.NewEncoder(w).Encode(response) + if encErr != nil { + t.Errorf("failed to encode response: %v", encErr) + } + })) + + return server, &captured, &mu +} + +func extractPort(serverURL string) string { + parts := strings.Split(serverURL, ":") + if len(parts) < 3 { + return "" + } + + return parts[len(parts)-1] +} + +// runPCSM runs the pcsm binary with the given arguments and environment variables. +func runPCSM(t *testing.T, args []string, env map[string]string) (string, string, error) { + t.Helper() + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + cmd := exec.CommandContext(ctx, binaryPath, args...) + + cmd.Env = os.Environ() + for k, v := range env { + cmd.Env = append(cmd.Env, fmt.Sprintf("%s=%s", k, v)) + } + + var stdout, stderr bytes.Buffer + cmd.Stdout = &stdout + cmd.Stderr = &stderr + + err := cmd.Run() + if errors.Is(ctx.Err(), context.DeadlineExceeded) { + return stdout.String(), stderr.String(), errCommandTimeout + } + + return stdout.String(), stderr.String(), err +} + +type standardResponse struct { + Ok bool `json:"ok"` +} + +type statusResponseMock struct { + Ok bool `json:"ok"` + State string `json:"state"` + LagTimeSeconds int64 `json:"lagTimeSeconds"` +} + +func TestStatusCommand(t *testing.T) { + t.Parallel() + + response := statusResponseMock{ + Ok: true, + State: "running", + LagTimeSeconds: 5, + } + + server, captured, mu := mockPCSMServer(t, response) + defer server.Close() + + port := extractPort(server.URL) + + stdout, stderr, err := runPCSM(t, []string{"--port", port, "status"}, nil) + + require.NoError(t, err, "stderr: %s", stderr) + + mu.Lock() + defer mu.Unlock() + + assert.Equal(t, http.MethodGet, captured.Method) + assert.Equal(t, "/status", captured.Path) + assert.Empty(t, captured.Body) + + assert.Contains(t, stdout, `"ok": true`) + assert.Contains(t, stdout, `"state": "running"`) +} + +func TestStartCommand(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + args []string + expectedBody map[string]any + }{ + { + name: "no flags", + args: []string{"start"}, + expectedBody: map[string]any{}, + }, + { + name: "pause-on-initial-sync", + args: []string{"start", "--pause-on-initial-sync"}, + expectedBody: map[string]any{"pauseOnInitialSync": true}, + }, + { + name: "include-namespaces single", + args: []string{"start", "--include-namespaces=db.coll"}, + expectedBody: map[string]any{"includeNamespaces": []any{"db.coll"}}, + }, + { + name: "include-namespaces multiple", + args: []string{"start", "--include-namespaces=db1.coll1,db2.coll2"}, + expectedBody: map[string]any{"includeNamespaces": []any{"db1.coll1", "db2.coll2"}}, + }, + { + name: "exclude-namespaces single", + args: []string{"start", "--exclude-namespaces=db.coll"}, + expectedBody: map[string]any{"excludeNamespaces": []any{"db.coll"}}, + }, + { + name: "exclude-namespaces multiple", + args: []string{"start", "--exclude-namespaces=db1.*,db2.coll"}, + expectedBody: map[string]any{"excludeNamespaces": []any{"db1.*", "db2.coll"}}, + }, + { + name: "all flags combined", + args: []string{ + "start", + "--pause-on-initial-sync", + "--include-namespaces=db1.coll1", + "--exclude-namespaces=db2.*", + }, + expectedBody: map[string]any{ + "pauseOnInitialSync": true, + "includeNamespaces": []any{"db1.coll1"}, + "excludeNamespaces": []any{"db2.*"}, + }, + }, + // Clone tuning flags + { + name: "clone-num-parallel-collections", + args: []string{"start", "--clone-num-parallel-collections=8"}, + expectedBody: map[string]any{"cloneNumParallelCollections": float64(8)}, + }, + { + name: "clone-num-read-workers", + args: []string{"start", "--clone-num-read-workers=16"}, + expectedBody: map[string]any{"cloneNumReadWorkers": float64(16)}, + }, + { + name: "clone-num-insert-workers", + args: []string{"start", "--clone-num-insert-workers=4"}, + expectedBody: map[string]any{"cloneNumInsertWorkers": float64(4)}, + }, + { + name: "clone-segment-size", + args: []string{"start", "--clone-segment-size=500MB"}, + expectedBody: map[string]any{"cloneSegmentSize": "500MB"}, + }, + { + name: "clone-read-batch-size", + args: []string{"start", "--clone-read-batch-size=32MiB"}, + expectedBody: map[string]any{"cloneReadBatchSize": "32MiB"}, + }, + { + name: "all clone flags combined", + args: []string{ + "start", + "--clone-num-parallel-collections=8", + "--clone-num-read-workers=16", + "--clone-num-insert-workers=4", + "--clone-segment-size=1GiB", + "--clone-read-batch-size=48MB", + }, + expectedBody: map[string]any{ + "cloneNumParallelCollections": float64(8), + "cloneNumReadWorkers": float64(16), + "cloneNumInsertWorkers": float64(4), + "cloneSegmentSize": "1GiB", + "cloneReadBatchSize": "48MB", + }, + }, + { + name: "all flags including clone options", + args: []string{ + "start", + "--pause-on-initial-sync", + "--include-namespaces=db1.coll1", + "--exclude-namespaces=db2.*", + "--clone-num-parallel-collections=4", + "--clone-segment-size=2GiB", + }, + expectedBody: map[string]any{ + "pauseOnInitialSync": true, + "includeNamespaces": []any{"db1.coll1"}, + "excludeNamespaces": []any{"db2.*"}, + "cloneNumParallelCollections": float64(4), + "cloneSegmentSize": "2GiB", + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + response := standardResponse{Ok: true} + server, captured, mu := mockPCSMServer(t, response) + defer server.Close() + + port := extractPort(server.URL) + + // Prepend port flag + args := append([]string{"--port", port}, tt.args...) + + _, stderr, err := runPCSM(t, args, nil) + require.NoError(t, err, "stderr: %s", stderr) + + mu.Lock() + defer mu.Unlock() + + assert.Equal(t, http.MethodPost, captured.Method) + assert.Equal(t, "/start", captured.Path) + + var actualBody map[string]any + if len(captured.Body) > 0 { + err = json.Unmarshal(captured.Body, &actualBody) + require.NoError(t, err) + } else { + actualBody = map[string]any{} + } + + expectedJSON, _ := json.Marshal(tt.expectedBody) + actualJSON, _ := json.Marshal(actualBody) + assert.JSONEq(t, string(expectedJSON), string(actualJSON)) + }) + } +} + +func TestFinalizeCommand(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + args []string + expectedBody map[string]any + }{ + { + name: "no flags", + args: []string{"finalize"}, + expectedBody: map[string]any{}, + }, + { + name: "ignore-history-lost", + args: []string{"finalize", "--ignore-history-lost"}, + expectedBody: map[string]any{"ignoreHistoryLost": true}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + response := standardResponse{Ok: true} + server, captured, mu := mockPCSMServer(t, response) + defer server.Close() + + port := extractPort(server.URL) + args := append([]string{"--port", port}, tt.args...) + + _, stderr, err := runPCSM(t, args, nil) + require.NoError(t, err, "stderr: %s", stderr) + + mu.Lock() + defer mu.Unlock() + + assert.Equal(t, http.MethodPost, captured.Method) + assert.Equal(t, "/finalize", captured.Path) + + var actualBody map[string]any + if len(captured.Body) > 0 { + err = json.Unmarshal(captured.Body, &actualBody) + require.NoError(t, err) + } else { + actualBody = map[string]any{} + } + + expectedJSON, _ := json.Marshal(tt.expectedBody) + actualJSON, _ := json.Marshal(actualBody) + assert.JSONEq(t, string(expectedJSON), string(actualJSON)) + }) + } +} + +func TestPauseCommand(t *testing.T) { + t.Parallel() + + response := standardResponse{Ok: true} + server, captured, mu := mockPCSMServer(t, response) + defer server.Close() + + port := extractPort(server.URL) + + _, stderr, err := runPCSM(t, []string{"--port", port, "pause"}, nil) + require.NoError(t, err, "stderr: %s", stderr) + + mu.Lock() + defer mu.Unlock() + + assert.Equal(t, http.MethodPost, captured.Method) + assert.Equal(t, "/pause", captured.Path) + // Pause sends nil body, which becomes empty string + assert.Empty(t, captured.Body) +} + +func TestResumeCommand(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + args []string + expectedBody map[string]any + }{ + { + name: "no flags", + args: []string{"resume"}, + expectedBody: map[string]any{}, + }, + { + name: "from-failure", + args: []string{"resume", "--from-failure"}, + expectedBody: map[string]any{"fromFailure": true}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + response := standardResponse{Ok: true} + server, captured, mu := mockPCSMServer(t, response) + defer server.Close() + + port := extractPort(server.URL) + args := append([]string{"--port", port}, tt.args...) + + _, stderr, err := runPCSM(t, args, nil) + require.NoError(t, err, "stderr: %s", stderr) + + mu.Lock() + defer mu.Unlock() + + assert.Equal(t, http.MethodPost, captured.Method) + assert.Equal(t, "/resume", captured.Path) + + var actualBody map[string]any + if len(captured.Body) > 0 { + err = json.Unmarshal(captured.Body, &actualBody) + require.NoError(t, err) + } else { + actualBody = map[string]any{} + } + + expectedJSON, _ := json.Marshal(tt.expectedBody) + actualJSON, _ := json.Marshal(actualBody) + assert.JSONEq(t, string(expectedJSON), string(actualJSON)) + }) + } +} + +func TestPortConfiguration(t *testing.T) { + t.Parallel() + + t.Run("port via flag", func(t *testing.T) { + t.Parallel() + + response := statusResponseMock{Ok: true, State: "idle"} + server, captured, mu := mockPCSMServer(t, response) + defer server.Close() + + port := extractPort(server.URL) + + _, stderr, err := runPCSM(t, []string{"--port", port, "status"}, nil) + require.NoError(t, err, "stderr: %s", stderr) + + mu.Lock() + defer mu.Unlock() + + assert.Equal(t, "/status", captured.Path) + }) + + t.Run("port via PCSM_PORT env var", func(t *testing.T) { + t.Parallel() + + response := statusResponseMock{Ok: true, State: "idle"} + server, captured, mu := mockPCSMServer(t, response) + defer server.Close() + + port := extractPort(server.URL) + + _, stderr, err := runPCSM(t, []string{"status"}, map[string]string{ + "PCSM_PORT": port, + }) + require.NoError(t, err, "stderr: %s", stderr) + + mu.Lock() + defer mu.Unlock() + + assert.Equal(t, "/status", captured.Path) + }) + + t.Run("flag takes precedence over env var", func(t *testing.T) { + t.Parallel() + + // Create two servers - one for env var (wrong), one for flag (correct) + wrongResponse := statusResponseMock{Ok: false, State: "wrong"} + wrongServer, _, _ := mockPCSMServer(t, wrongResponse) + defer wrongServer.Close() + wrongPort := extractPort(wrongServer.URL) + + correctResponse := statusResponseMock{Ok: true, State: "correct"} + correctServer, captured, mu := mockPCSMServer(t, correctResponse) + defer correctServer.Close() + correctPort := extractPort(correctServer.URL) + + stdout, stderr, err := runPCSM(t, []string{"--port", correctPort, "status"}, map[string]string{ + "PCSM_PORT": wrongPort, + }) + require.NoError(t, err, "stderr: %s", stderr) + + mu.Lock() + defer mu.Unlock() + + // Should have hit the correct server (flag) + assert.Equal(t, "/status", captured.Path) + assert.Contains(t, stdout, `"state": "correct"`) + }) +} + +func TestStartConfigPrecedence(t *testing.T) { + t.Parallel() + + t.Run("flag takes precedence over env var for clone-num-parallel-collections", func(t *testing.T) { + t.Parallel() + + response := standardResponse{Ok: true} + server, captured, mu := mockPCSMServer(t, response) + defer server.Close() + + port := extractPort(server.URL) + + _, stderr, err := runPCSM(t, + []string{"--port", port, "start", "--clone-num-parallel-collections=8"}, + map[string]string{ + "PCSM_CLONE_NUM_PARALLEL_COLLECTIONS": "2", + }) + require.NoError(t, err, "stderr: %s", stderr) + + mu.Lock() + defer mu.Unlock() + + assert.Equal(t, http.MethodPost, captured.Method) + assert.Equal(t, "/start", captured.Path) + + var actualBody map[string]any + err = json.Unmarshal(captured.Body, &actualBody) + require.NoError(t, err) + + assert.EqualValues(t, 8, actualBody["cloneNumParallelCollections"]) + }) + + t.Run("env var is used when flag not provided", func(t *testing.T) { + t.Parallel() + + response := standardResponse{Ok: true} + server, captured, mu := mockPCSMServer(t, response) + defer server.Close() + + port := extractPort(server.URL) + + _, stderr, err := runPCSM(t, + []string{"--port", port, "start"}, + map[string]string{ + "PCSM_CLONE_NUM_PARALLEL_COLLECTIONS": "4", + }) + require.NoError(t, err, "stderr: %s", stderr) + + mu.Lock() + defer mu.Unlock() + + assert.Equal(t, http.MethodPost, captured.Method) + assert.Equal(t, "/start", captured.Path) + + var actualBody map[string]any + err = json.Unmarshal(captured.Body, &actualBody) + require.NoError(t, err) + + assert.EqualValues(t, 4, actualBody["cloneNumParallelCollections"]) + }) +} + +func TestConnectionRefused(t *testing.T) { + t.Parallel() + + _, stderr, err := runPCSM(t, []string{"--port", "59999", "status"}, nil) + + // Should fail with connection refused + require.Error(t, err) + // Error should mention connection issue + combinedOutput := stderr + assert.True(t, + strings.Contains(combinedOutput, "connection refused") || + strings.Contains(combinedOutput, "connect:") || + strings.Contains(combinedOutput, "dial"), + "expected connection error, got: %s", combinedOutput) +} diff --git a/pcsm/clone.go b/pcsm/clone.go index 9b7fc00..8f23585 100644 --- a/pcsm/clone.go +++ b/pcsm/clone.go @@ -22,12 +22,32 @@ import ( "github.com/percona/percona-clustersync-mongodb/topo" ) +// CloneOptions configures the clone behavior. +type CloneOptions struct { + // Parallelism is the number of collections to clone in parallel. + // Default: 2 (config.DefaultCloneNumParallelCollection) + Parallelism int + // ReadWorkers is the number of read workers during clone. + // Default: auto (0 = runtime.NumCPU()/4) + ReadWorkers int + // InsertWorkers is the number of insert workers during clone. + // Default: auto (0 = runtime.NumCPU()*2) + InsertWorkers int + // SegmentSizeBytes is the segment size for clone operations in bytes. + // Default: auto (0 = calculated per collection) + SegmentSizeBytes int64 + // ReadBatchSizeBytes is the read batch size during clone in bytes. + // Default: ~47.5MB (config.DefaultCloneReadBatchSizeBytes) + ReadBatchSizeBytes int32 +} + // Clone handles the cloning of data from a source MongoDB to a target MongoDB. type Clone struct { source *mongo.Client // Source MongoDB client target *mongo.Client // Target MongoDB client catalog *Catalog // Catalog for managing collections and indexes nsFilter sel.NSFilter // Namespace filter + options *CloneOptions // Clone options lock sync.Mutex err error // Error encountered during the cloning process @@ -74,12 +94,19 @@ func (cs *CloneStatus) IsFinished() bool { return !cs.FinishTime.IsZero() } -func NewClone(source, target *mongo.Client, catalog *Catalog, nsFilter sel.NSFilter) *Clone { +// NewClone creates a new Clone instance with the given options. +func NewClone( + source, target *mongo.Client, + catalog *Catalog, + nsFilter sel.NSFilter, + opts *CloneOptions, +) *Clone { return &Clone{ source: source, target: target, catalog: catalog, nsFilter: nsFilter, + options: opts, doneSig: make(chan struct{}), } } @@ -292,7 +319,7 @@ func (c *Clone) run() error { func (c *Clone) doClone(ctx context.Context, namespaces []namespaceInfo) error { cloneLogger := log.Ctx(ctx) - numParallelCollections := config.CloneNumParallelCollections() + numParallelCollections := c.options.Parallelism if numParallelCollections < 1 { numParallelCollections = config.DefaultCloneNumParallelCollection } @@ -300,10 +327,10 @@ func (c *Clone) doClone(ctx context.Context, namespaces []namespaceInfo) error { cloneLogger.Debugf("NumParallelCollections: %d", numParallelCollections) copyManager := NewCopyManager(c.source, c.target, CopyManagerOptions{ - NumReadWorkers: config.CloneNumReadWorkers(), - NumInsertWorkers: config.CloneNumInsertWorkers(), - SegmentSizeBytes: config.CloneSegmentSizeBytes(), - ReadBatchSizeBytes: config.CloneReadBatchSizeBytes(), + NumReadWorkers: c.options.ReadWorkers, + NumInsertWorkers: c.options.InsertWorkers, + SegmentSizeBytes: c.options.SegmentSizeBytes, + ReadBatchSizeBytes: c.options.ReadBatchSizeBytes, }) defer copyManager.Close() diff --git a/pcsm/copy_test.go b/pcsm/copy_test.go index b032dec..d3296ec 100644 --- a/pcsm/copy_test.go +++ b/pcsm/copy_test.go @@ -93,7 +93,7 @@ func BenchmarkRead(b *testing.B) { ctx := b.Context() ns := getNamespace() - mc, err := topo.Connect(ctx, getSourceURI()) + mc, err := topo.Connect(ctx, getSourceURI(), &config.Config{}) if err != nil { b.Fatal(err) } @@ -147,7 +147,7 @@ func BenchmarkInsert(b *testing.B) { ctx := b.Context() ns := getNamespace() - mc, err := topo.Connect(ctx, getTargetURI()) + mc, err := topo.Connect(ctx, getTargetURI(), &config.Config{}) if err != nil { b.Fatal(err) } diff --git a/pcsm/pcsm.go b/pcsm/pcsm.go index 84952fa..23f2f65 100644 --- a/pcsm/pcsm.go +++ b/pcsm/pcsm.go @@ -167,8 +167,9 @@ func (ml *PCSM) Recover(ctx context.Context, data []byte) error { nsFilter := sel.MakeFilter(cp.NSInclude, cp.NSExclude) catalog := NewCatalog(ml.target) - clone := NewClone(ml.source, ml.target, catalog, nsFilter) - repl := NewRepl(ml.source, ml.target, catalog, nsFilter) + // Use empty options for recovery (clone tuning is less relevant when resuming from checkpoint) + clone := NewClone(ml.source, ml.target, catalog, nsFilter, &CloneOptions{}) + repl := NewRepl(ml.source, ml.target, catalog, nsFilter, &ReplOptions{}) if cp.Catalog != nil { err = catalog.Recover(cp.Catalog) @@ -283,12 +284,17 @@ func (ml *PCSM) resetError() { // StartOptions represents the options for starting the PCSM. type StartOptions struct { - // PauseOnInitialSync indicates whether to finalize after the initial sync. + // PauseOnInitialSync indicates whether to pause after the initial sync completes. PauseOnInitialSync bool // IncludeNamespaces are the namespaces to include. IncludeNamespaces []string // ExcludeNamespaces are the namespaces to exclude. ExcludeNamespaces []string + + // Clone contains clone tuning options. + Clone CloneOptions + // Repl contains replication behavior options. + Repl ReplOptions } // Start starts the replication process with the given options. @@ -319,8 +325,8 @@ func (ml *PCSM) Start(_ context.Context, options *StartOptions) error { ml.nsFilter = sel.MakeFilter(ml.nsInclude, ml.nsExclude) ml.pauseOnInitialSync = options.PauseOnInitialSync ml.catalog = NewCatalog(ml.target) - ml.clone = NewClone(ml.source, ml.target, ml.catalog, ml.nsFilter) - ml.repl = NewRepl(ml.source, ml.target, ml.catalog, ml.nsFilter) + ml.clone = NewClone(ml.source, ml.target, ml.catalog, ml.nsFilter, &options.Clone) + ml.repl = NewRepl(ml.source, ml.target, ml.catalog, ml.nsFilter, &options.Repl) ml.state = StateRunning go ml.run() diff --git a/pcsm/repl.go b/pcsm/repl.go index cd4f84f..67f4a14 100644 --- a/pcsm/repl.go +++ b/pcsm/repl.go @@ -28,6 +28,13 @@ var ( const advanceTimePseudoEvent = "@tick" +// ReplOptions configures the replication behavior. +type ReplOptions struct { + // UseCollectionBulkWrite indicates whether to use collection-level bulk write + // instead of client bulk write. Default: false (use client bulk write). + UseCollectionBulkWrite bool +} + // Repl handles replication from a source MongoDB to a target MongoDB. type Repl struct { source *mongo.Client // Source MongoDB client @@ -36,6 +43,8 @@ type Repl struct { nsFilter sel.NSFilter // Namespace filter catalog *Catalog // Catalog for managing collections and indexes + options *ReplOptions // Replication options + lastReplicatedOpTime bson.Timestamp lock sync.Mutex @@ -84,12 +93,19 @@ func (rs *ReplStatus) IsPaused() bool { return !rs.PauseTime.IsZero() } -func NewRepl(source, target *mongo.Client, catalog *Catalog, nsFilter sel.NSFilter) *Repl { +// NewRepl creates a new Repl instance. +func NewRepl( + source, target *mongo.Client, + catalog *Catalog, + nsFilter sel.NSFilter, + opts *ReplOptions, +) *Repl { return &Repl{ source: source, target: target, nsFilter: nsFilter, catalog: catalog, + options: opts, pauseC: make(chan struct{}), doneSig: make(chan struct{}), } @@ -221,7 +237,7 @@ func (r *Repl) Start(ctx context.Context, startAt bson.Timestamp) error { return errors.Wrap(err, "major version") } - if topo.Support(targetVer).ClientBulkWrite() && !config.UseCollectionBulkWrite() { + if topo.Support(targetVer).ClientBulkWrite() && !r.options.UseCollectionBulkWrite { r.bulkWrite = newClientBulkWrite(config.BulkOpsSize, targetVer.Major() < 8) //nolint:mnd } else { r.bulkWrite = newCollectionBulkWrite(config.BulkOpsSize, targetVer.Major() < 8) //nolint:mnd diff --git a/tests/perf_test.go b/tests/perf_test.go index 0e2aa2e..407c1c4 100644 --- a/tests/perf_test.go +++ b/tests/perf_test.go @@ -27,7 +27,7 @@ func BenchmarkInsertOne(b *testing.B) { b.Fatal("no MongoDB URI provided") } - client, err := topo.Connect(b.Context(), mongodbURI) + client, err := topo.Connect(b.Context(), mongodbURI, &config.Config{}) if err != nil { b.Fatalf("Failed to connect to MongoDB: %v", err) } @@ -63,7 +63,8 @@ func BenchmarkReplaceOne(b *testing.B) { b.Fatal("no MongoDB URI provided") } - client, err := topo.Connect(b.Context(), mongodbURI) + cfg := &config.Config{} + client, err := topo.Connect(b.Context(), mongodbURI, cfg) if err != nil { b.Fatalf("Failed to connect to MongoDB: %v", err) } @@ -138,14 +139,15 @@ func performIndexTest(b *testing.B, opts performIndexTestOptions) { } ctx := b.Context() + cfg := &config.Config{} - source, err := topo.Connect(ctx, sourceURI) + source, err := topo.Connect(ctx, sourceURI, cfg) if err != nil { b.Fatalf("Failed to connect to MongoDB: %v", err) } defer source.Disconnect(ctx) //nolint:errcheck - target, err := topo.Connect(ctx, targetURI) + target, err := topo.Connect(ctx, targetURI, cfg) if err != nil { b.Fatalf("Failed to connect to MongoDB: %v", err) } diff --git a/topo/connect.go b/topo/connect.go index c140f61..32907bf 100644 --- a/topo/connect.go +++ b/topo/connect.go @@ -19,23 +19,8 @@ import ( "github.com/percona/percona-clustersync-mongodb/util" ) -type ConnectOptions struct { - Compressors []string -} - // Connect establishes a connection to a MongoDB instance using the provided URI. -// If the URI is empty, it returns an error. -func Connect(ctx context.Context, uri string) (*mongo.Client, error) { - return ConnectWithOptions(ctx, uri, &ConnectOptions{}) -} - -// ConnectWithOptions establishes a connection to a MongoDB instance using the provided URI and options. -// If the URI is empty, it returns an error. -func ConnectWithOptions( - ctx context.Context, - uri string, - connOpts *ConnectOptions, -) (*mongo.Client, error) { +func Connect(ctx context.Context, uri string, cfg *config.Config) (*mongo.Client, error) { if uri == "" { return nil, errors.New("invalid MongoDB URI") } @@ -58,10 +43,10 @@ func ConnectWithOptions( SetReadPreference(readpref.Primary()). SetReadConcern(readconcern.Majority()). SetWriteConcern(writeconcern.Majority()). - SetTimeout(config.OperationMongoDBCliTimeout()) + SetTimeout(cfg.MongoDB.OperationTimeout) - if connOpts != nil && connOpts.Compressors != nil { - opts.SetCompressors(connOpts.Compressors) + if uri == cfg.Target && len(cfg.MongoDB.TargetCompressors) > 0 { + opts.SetCompressors(cfg.MongoDB.TargetCompressors) } if config.MongoLogEnabled {