From cdbbae3e271b4bd3118407ccaba0fc48efb44438 Mon Sep 17 00:00:00 2001 From: Pavlo Golub Date: Mon, 29 Sep 2025 11:31:23 +0200 Subject: [PATCH 1/7] [!] add YAML-based chain definitions Introduces support for defining task chains in YAML format, enhancing the readability and maintainability of chain configurations. Users can now define chains and tasks in a structured YAML file and load them directly into database. Simplifies chain configuration compared to SQL inserts. Improves user experience with a human-readable format. Facilitates better version control and collaboration on chain definitions. --- README.md | 34 +++ docs/components.md | 27 +- docs/yaml-format.md | 135 +++++++++ docs/yaml-usage-guide.md | 400 ++++++++++++++++++++++++++ go.mod | 2 +- internal/config/cmdparser.go | 10 +- internal/pgengine/access.go | 8 +- internal/pgengine/bootstrap.go | 50 +++- internal/pgengine/pgengine_test.go | 3 +- internal/pgengine/transaction.go | 17 +- internal/pgengine/transaction_test.go | 15 +- internal/pgengine/types.go | 41 ++- internal/pgengine/yaml.go | 311 ++++++++++++++++++++ internal/pgengine/yaml_test.go | 194 +++++++++++++ internal/scheduler/chain.go | 8 +- internal/scheduler/chain_test.go | 3 +- mkdocs.yml | 2 + samples/yaml/Backup.yaml | 65 +++++ samples/yaml/Basic.yaml | 16 ++ samples/yaml/Chain.yaml | 40 +++ samples/yaml/CronStyle.yaml | 23 ++ samples/yaml/ETLPipeline.yaml | 69 +++++ samples/yaml/MultipleChains.yaml | 75 +++++ samples/yaml/Parameters.yaml | 51 ++++ samples/yaml/Shell.yaml | 24 ++ 25 files changed, 1564 insertions(+), 59 deletions(-) create mode 100644 docs/yaml-format.md create mode 100644 docs/yaml-usage-guide.md create mode 100644 internal/pgengine/yaml.go create mode 100644 internal/pgengine/yaml_test.go create mode 100644 samples/yaml/Backup.yaml create mode 100644 samples/yaml/Basic.yaml create mode 100644 samples/yaml/Chain.yaml create mode 100644 samples/yaml/CronStyle.yaml create mode 100644 samples/yaml/ETLPipeline.yaml create mode 
100644 samples/yaml/MultipleChains.yaml create mode 100644 samples/yaml/Parameters.yaml create mode 100644 samples/yaml/Shell.yaml diff --git a/README.md b/README.md index a018d817..134a0236 100644 --- a/README.md +++ b/README.md @@ -56,6 +56,40 @@ SELECT timetable.add_job('reindex-job', '0 0 * * 7', 'bash', - Full support for database driven logging - Enhanced cron-style scheduling - Optional concurrency protection +- **NEW**: YAML-based chain definitions for easy configuration + +## YAML Configuration + +You can now define chains using YAML files instead of SQL inserts, making configuration more readable and maintainable: + +```yaml +chains: + - name: "Daily ETL Pipeline" + schedule: "0 2 * * *" # 2 AM daily + live: true + max_instances: 1 + timeout: 3600000 # 1 hour + + tasks: + - name: "Extract data" + command: "SELECT extract_sales_data($1)" + parameters: ["yesterday"] + + - name: "Transform data" + command: "CALL transform_sales_data()" + autonomous: true + + - name: "Load to warehouse" + command: "CALL load_to_warehouse()" +``` + +Load YAML chains with: + +```bash +pg_timetable --file chains.yaml --connstr "postgresql://user:pass@host/db" +``` + +See [`samples/yaml/`](samples/yaml/) for more examples and [`docs/yaml-format.md`](docs/yaml-format.md) for complete format specification. ## Installation diff --git a/docs/components.md b/docs/components.md index b7e7eada..6e7f1c3a 100644 --- a/docs/components.md +++ b/docs/components.md @@ -13,12 +13,15 @@ The scheduling in **pg_timetable** encompasses three different abstraction level Currently, there are three different kinds of commands: ### `SQL` + SQL snippet. Starting a cleanup, refreshing a materialized view or processing data. ### `PROGRAM` + External Command. Anything that can be called as an external binary, including shells, e.g. `bash`, `pwsh`, etc. The external command will be called using golang's [exec.CommandContext](https://pkg.go.dev/os/exec#CommandContext). ### `BUILTIN` + Internal Command. 
A prebuilt functionality included in **pg_timetable**. These include: * *NoOp* @@ -78,42 +81,52 @@ In most cases, they have to be brought to live by passing input parameters to th Depending on the **command** kind argument can be represented by different *JSON* values. #### `SQL` + Schema: `array` Example: + ```sql '[ "one", 2, 3.14, false ]'::jsonb ``` #### `PROGRAM` + Schema: `array of strings` Example: + ```sql '["-x", "Latin-ASCII", "-o", "orte_ansi.txt", "orte.txt"]'::jsonb ``` #### `BUILTIN: Sleep` + Schema: `integer` Example: + ```sql '5' :: jsonb ``` #### `BUILTIN: Log` + Schema: `any` Examples: + ```sql '"WARNING"'::jsonb '{"Status": "WARNING"}'::jsonb ``` #### `BUILTIN: SendMail` + Schema: `object` Example: + ```sql '{ "username": "user@example.com", @@ -133,9 +146,11 @@ Example: ``` #### `BUILTIN: Download` + Schema: `object` Example: + ```sql '{ "workersnum": 2, @@ -145,9 +160,11 @@ Example: ``` #### `BUILTIN: CopyFromFile` + Schema: `object` Example: + ```sql '{ "sql": "COPY location FROM STDIN", @@ -156,9 +173,11 @@ Example: ``` #### `BUILTIN: CopyToFile` + Schema: `object` Example: + ```sql '{ "sql": "COPY location TO STDOUT", @@ -167,10 +186,12 @@ Example: ``` #### `BUILTIN: Shutdown` -*value ignored* + +value ignored #### `BUILTIN: NoOp` -*value ignored* + +value ignored ## Chain @@ -202,4 +223,4 @@ Once tasks have been arranged, they have to be scheduled as a **chain**. For thi -- Run VACUUM at 00:05 every day in August UTC SELECT timetable.add_job('execute-func', '5 0 * 8 *', 'VACUUM'); - ``` \ No newline at end of file + ``` diff --git a/docs/yaml-format.md b/docs/yaml-format.md new file mode 100644 index 00000000..abf076e7 --- /dev/null +++ b/docs/yaml-format.md @@ -0,0 +1,135 @@ +# YAML Chain Definition Format for pg_timetable + +This document defines the YAML format for defining chains of scheduled tasks in pg_timetable. 
+ +## YAML Schema + +```yaml +# Top-level structure +chains: + - name: "chain-name" # Required: chain_name (TEXT, unique) + schedule: "* * * * *" # Required: run_at (cron format) + live: true # Optional: live (BOOLEAN), default: false + max_instances: 1 # Optional: max_instances (INTEGER) + timeout: 30000 # Optional: timeout in milliseconds (INTEGER) + self_destruct: false # Optional: self_destruct (BOOLEAN), default: false + exclusive: false # Optional: exclusive_execution (BOOLEAN), default: false + client_name: "worker-1" # Optional: client_name (TEXT) + on_error: "SELECT log_error()" # Optional: on_error SQL (TEXT) + + tasks: # Required: array of tasks + - name: "task-1" # Optional: task_name (TEXT) + kind: "SQL" # Optional: kind (SQL|PROGRAM|BUILTIN), default: SQL + command: "SELECT $1, $2" # Required: command (TEXT) + parameters: # Optional: parameters (array of execution parameters) + - ["value1", 42] # First execution with these parameters + - ["value2", 99] # Second execution with different parameters + run_as: "postgres" # Optional: run_as (TEXT) - role for SET ROLE + connect_string: "postgresql://user@host/otherdb" # Optional: database_connection (TEXT) + ignore_error: false # Optional: ignore_error (BOOLEAN), default: false + autonomous: false # Optional: autonomous (BOOLEAN), default: false + timeout: 5000 # Optional: timeout in milliseconds (INTEGER) + + - name: "task-2" + kind: "PROGRAM" + command: "bash" + parameters: ["-c", "echo hello"] + ignore_error: true +``` + +## Field Mappings + +### Chain Level + +| YAML Field | DB Column | Type | Default | Description | +|------------|-----------|------|---------|-------------| +| `name` | `chain_name` | TEXT | **required** | Unique chain identifier | +| `schedule` | `run_at` | cron | **required** | Cron-style schedule | +| `live` | `live` | BOOLEAN | `false` | Whether chain is active | +| `max_instances` | `max_instances` | INTEGER | `null` | Max parallel instances | +| `timeout` | `timeout` | INTEGER | 
`0` | Chain timeout (ms) | +| `self_destruct` | `self_destruct` | BOOLEAN | `false` | Delete after success | +| `exclusive` | `exclusive_execution` | BOOLEAN | `false` | Pause other chains | +| `client_name` | `client_name` | TEXT | `null` | Restrict to specific client | +| `on_error` | `on_error` | TEXT | `null` | Error handling SQL | + +### Task Level + +| YAML Field | DB Column | Type | Default | Description | +|------------|-----------|------|---------|-------------| +| `name` | `task_name` | TEXT | `null` | Task description | +| `kind` | `kind` | ENUM | `'SQL'` | Command type (SQL/PROGRAM/BUILTIN) | +| `command` | `command` | TEXT | **required** | Command to execute | +| `parameters` | via `timetable.parameter` | Array of JSONB | `null` | Array of parameter values, each causing separate task execution | +| `run_as` | `run_as` | TEXT | `null` | Role for SET ROLE | +| `connect_string` | `database_connection` | TEXT | `null` | Connection string | +| `ignore_error` | `ignore_error` | BOOLEAN | `false` | Continue on error | +| `autonomous` | `autonomous` | BOOLEAN | `false` | Execute outside transaction | +| `timeout` | `timeout` | INTEGER | `0` | Task timeout (ms) | + +## Task Ordering + +Tasks are ordered sequentially within a chain based on their array position. The system will automatically assign appropriate `task_order` values with spacing (e.g., 10, 20, 30) to allow future insertions. 
+ +## Examples + +### Simple SQL Job + +```yaml +chains: + - name: "daily-report" + schedule: "0 9 * * *" # 9 AM daily + live: true + tasks: + - name: "generate-report" + command: "CALL generate_daily_report()" +``` + +### Multi-task Chain + +```yaml +chains: + - name: "etl-pipeline" + schedule: "0 2 * * *" # 2 AM daily + live: true + max_instances: 1 + timeout: 3600000 # 1 hour + + tasks: + - name: "extract-data" + command: "SELECT extract_sales_data($1)" + parameters: ["2023-01-01"] + + - name: "transform-data" + command: "CALL transform_sales_data()" + autonomous: true + + - name: "load-data" + command: "CALL load_to_warehouse()" + ignore_error: false +``` + +### Program Task + +```yaml +chains: + - name: "backup-job" + schedule: "0 3 * * 0" # Sunday 3 AM + live: true + + tasks: + - name: "pg-dump" + kind: "PROGRAM" + command: "pg_dump" + parameters: + - ["-h", "localhost", "-U", "postgres", "-d", "mydb", "-f", "/backups/mydb.sql"] +``` + +## Validation Rules + +1. **Required Fields**: `name`, `schedule`, `tasks`, and `command` for each task +2. **Unique Names**: Chain names must be unique across the database +3. **Valid Cron**: Schedule must be valid cron format (5 fields) +4. **Valid Kind**: Task kind must be one of: SQL, PROGRAM, BUILTIN +5. **Parameter Types**: Parameters must be strings or numbers (converted to JSONB array) +6. **Timeout Values**: Must be non-negative integers (milliseconds) diff --git a/docs/yaml-usage-guide.md b/docs/yaml-usage-guide.md new file mode 100644 index 00000000..d9c7f19c --- /dev/null +++ b/docs/yaml-usage-guide.md @@ -0,0 +1,400 @@ +# YAML Chain Configuration Guide + +This guide explains how to use YAML files to define pg_timetable chains as an alternative to SQL-based configuration. + +## Overview + +YAML chain definitions provide a human-readable way to create scheduled task chains without writing SQL. 
Benefits include: + +- Creating complex multi-step workflows with clear structure +- Version controlling your chain configurations +- Easy review and modification of scheduled tasks +- Sharing chain templates across environments + +## Basic Usage + +```bash +# Load YAML chains +pg_timetable --file chains.yaml postgresql://user:pass@host/db + +# Validate YAML without importing +pg_timetable --file chains.yaml --validate + +# Replace existing chains with same names +pg_timetable --file chains.yaml --replace postgresql://user:pass@host/db +``` + +## YAML Format + +### Basic Structure + +```yaml +chains: + - name: "chain-name" # Required: unique identifier + schedule: "* * * * *" # Required: cron format + live: true # Optional: enable/disable chain + max_instances: 1 # Optional: max parallel executions + timeout: 30000 # Optional: timeout in milliseconds + self_destruct: false # Optional: delete after success + exclusive: false # Optional: pause other chains while running + client_name: "worker-1" # Optional: restrict to specific client + on_error: "SELECT log_error($1)" # Optional: error handling SQL + tasks: # Required: array of tasks + - name: "task-name" # Optional: task description + kind: "SQL" # Optional: SQL, PROGRAM, or BUILTIN + command: "SELECT now()" # Required: command to execute + run_as: "postgres" # Optional: role for SET ROLE + connect_string: "postgresql://user@host/otherdb" # Optional: different database connection + ignore_error: false # Optional: continue on error + autonomous: false # Optional: run outside transaction + timeout: 5000 # Optional: task timeout in ms + parameters: # Optional: task parameters, each entry causes separate execution + - ["value1", 42] # Parameters for SQL tasks are arrays of values +``` + +### Task Parameters + +Each task can have multiple parameter entries, with each entry causing a separate execution: + +```yaml +# SQL task parameters (arrays of values) +- name: "sql-task" + kind: "SQL" + command: "SELECT $1, $2, $3, 
$4" + parameters: + - ["one", 2, 3.14, false] # First execution + - ["two", 4, 6.28, true] # Second execution + +# PROGRAM task parameters (arrays of command-line arguments) +- name: "program-task" + kind: "PROGRAM" + command: "iconv" + parameters: + - ["-x", "Latin-ASCII", "-o", "file1.txt", "input1.txt"] + - ["-x", "UTF-8", "-o", "file2.txt", "input2.txt"] + +# BUILTIN: Sleep task (integer values) +- name: "sleep-task" + kind: "BUILTIN" + command: "Sleep" + parameters: + - 5 # Sleep for 5 seconds + - 10 # Then sleep for 10 seconds + +# BUILTIN: Log task (string or object values) +- name: "log-task" + kind: "BUILTIN" + command: "Log" + parameters: + - "WARNING: Simple message" + - {"level": "WARNING", "details": "Object message"} + +# BUILTIN: SendMail task (complex object) +- name: "mail-task" + kind: "BUILTIN" + command: "SendMail" + parameters: + - username: "user@example.com" + password: "password123" + serverhost: "smtp.example.com" + serverport: 587 + senderaddr: "user@example.com" + toaddr: ["recipient@example.com"] + subject: "Notification" + msgbody: "

<h2>Hello User</h2>

" + contenttype: "text/html; charset=UTF-8" +``` + +### Examples + +#### Simple SQL Job + +```yaml +chains: + - name: "daily-cleanup" + schedule: "0 2 * * *" # 2 AM daily + live: true + + tasks: + - name: "vacuum-tables" + command: "VACUUM ANALYZE" +``` + +#### Multi-Step Chain + +```yaml +chains: + - name: "data-pipeline" + schedule: "0 1 * * *" # 1 AM daily + live: true + max_instances: 1 + timeout: 7200000 # 2 hours + + tasks: + - name: "extract" + command: | + CREATE TEMP TABLE temp_data AS + SELECT * FROM source_table + WHERE date >= CURRENT_DATE - INTERVAL '1 day' + + - name: "validate" + command: | + DO $$ + BEGIN + IF (SELECT COUNT(*) FROM temp_data) = 0 THEN + RAISE EXCEPTION 'No data to process'; + END IF; + END $$ + + - name: "transform" + command: "CALL transform_data_procedure()" + autonomous: true + + - name: "load" + command: "INSERT INTO target_table SELECT * FROM temp_data" +``` + +#### Program Tasks + +```yaml +chains: + - name: "backup-job" + schedule: "0 3 * * 0" # Sunday 3 AM + live: true + client_name: "backup-worker" + + tasks: + - name: "database-backup" + kind: "PROGRAM" + command: "pg_dump" + parameters: + - ["-h", "localhost", "-U", "postgres", "-d", "mydb", "-f", "/backups/mydb.sql"] + timeout: 3600000 # 1 hour + + - name: "compress-backup" + kind: "PROGRAM" + command: "gzip" + parameters: + - ["/backups/mydb.sql"] +``` + +#### Multiple Chains in One File + +```yaml +chains: + # Monitoring chain + - name: "health-check" + schedule: "*/15 * * * *" # Every 15 minutes + live: true + + tasks: + - command: "SELECT check_database_health()" + + # Cleanup chain + - name: "hourly-cleanup" + schedule: "0 * * * *" # Every hour + live: true + + tasks: + - command: "DELETE FROM logs WHERE created_at < now() - interval '7 days'" +``` + +## Advanced Features + +### Error Handling + +Control error behavior with `ignore_error` and `on_error`: + +```yaml +chains: + - name: "resilient-chain" + on_error: | + SELECT pg_notify('monitoring', + 
format('{"ConfigID": %s, "Message": "Something bad happened"}', + current_setting('pg_timetable.current_chain_id')::bigint)) + + tasks: + - name: "risky-task" + command: "SELECT might_fail()" + ignore_error: true # Continue chain execution even if this task fails + + - name: "cleanup-task" + command: "SELECT cleanup()" # Always runs, even if previous task failed +``` + +### Transaction Control + +Use `autonomous: true` for tasks that need to run outside the main transaction: + +```yaml +tasks: + - name: "vacuum-task" + command: "VACUUM FULL heavy_table" + autonomous: true # Required for VACUUM FULL + + - name: "create-database" + command: "CREATE DATABASE new_db" + autonomous: true # CREATE DATABASE requires autonomous transaction +``` + +### Remote Databases + +Execute tasks on different databases: + +```yaml +tasks: + - name: "cross-database-task" + command: "SELECT sync_data()" + connect_string: "postgresql://user:pass@other-host/other-db" +``` + +## Validation + +YAML files are validated when loaded: + +- **Syntax**: Valid YAML format +- **Structure**: Required fields present +- **Cron**: Valid 5-field cron expressions +- **Task kinds**: Must be SQL, PROGRAM, or BUILTIN +- **Timeouts**: Non-negative integers + +Use `--validate` to check files without importing: + +```bash +pg_timetable --file chains.yaml --validate +``` + +## Migration from SQL + +### Converting Existing Chains + +To convert SQL-based chains to YAML: + +1. **Query chain and tasks information**: + + ```sql + SELECT * + FROM timetable.chain c + WHERE c.chain_name = 'my-chain'; + + SELECT t.* + FROM timetable.task t JOIN + timetable.chain c ON t.chain_id = c.chain_id AND c.chain_name = 'my-chain' + ORDER BY t.task_order; + ``` + +2. **Map to YAML format**: + - `chain_name` → `name` + - `run_at` → `schedule` + - `live` → `live` + - `max_instances` → `max_instances` + - Task fields map directly + +3. 
**Test the conversion**: + + ```bash + pg_timetable --file converted.yaml --validate + ``` + +### Example Migration + +**Original SQL**: + +```sql +SELECT timetable.add_job( + job_name => 'daily-report', + job_schedule => '0 9 * * *', + job_command => 'CALL generate_report()', + job_live => TRUE +); +``` + +**Converted YAML**: + +```yaml +chains: + - name: "daily-report" + schedule: "0 9 * * *" + live: true + + tasks: + - command: "CALL generate_report()" +``` + +## Best Practices + +### Naming Conventions + +- Use descriptive, kebab-case names +- Include environment in name for clarity +- Group related chains in same file + +### Documentation + +- Use YAML comments to document complex logic +- Include purpose and dependencies in task names +- Document parameter meanings + +```yaml +chains: + - name: "etl-sales-data" + # Processes daily sales data from external API + # Depends on: external API availability, sales_raw table + schedule: "0 2 * * *" + + tasks: + - name: "extract-from-api" + # Fetches last 24h of sales data from REST API + command: "SELECT fetch_sales_data($1)" + parameters: ["yesterday"] +``` + +### Testing + +- Always validate YAML before deployment +- Test with `--validate` flag +- Use non-live chains for testing +- Keep backups of working configurations + +### Version Control + +- Store YAML files in version control +- Use meaningful commit messages +- Tag releases for production deployments +- Review changes before merging + +## Troubleshooting + +### Common Issues + +**Invalid YAML syntax**: + +```text +Error: failed to parse YAML: yaml: line 5: found character that cannot start any token +``` + +→ Check indentation and quotes + +**Invalid cron format**: + +```text +Error: invalid cron format: 0 9 * * (expected 5 fields) +``` + +→ Ensure cron has exactly 5 fields + +**Chain already exists**: + +```text +Error: chain 'my-chain' already exists (use --replace flag to overwrite) +``` + +→ Use `--replace` flag or choose different name + +**Missing 
required fields**: + +```text +Error: chain 1: chain name is required +``` + +→ Check all required fields are present diff --git a/go.mod b/go.mod index b2a9b814..2ca87030 100644 --- a/go.mod +++ b/go.mod @@ -17,6 +17,7 @@ require ( github.com/testcontainers/testcontainers-go v0.39.0 github.com/testcontainers/testcontainers-go/modules/postgres v0.39.0 gopkg.in/natefinch/lumberjack.v2 v2.2.1 + gopkg.in/yaml.v3 v3.0.1 ) require ( @@ -85,5 +86,4 @@ require ( google.golang.org/grpc v1.75.0 // indirect google.golang.org/protobuf v1.36.7 // indirect gopkg.in/alexcesaro/quotedprintable.v3 v3.0.0-20150716171945-2caba252f4dc // indirect - gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/internal/config/cmdparser.go b/internal/config/cmdparser.go index 2ecc09dd..1c95eb9d 100644 --- a/internal/config/cmdparser.go +++ b/internal/config/cmdparser.go @@ -21,10 +21,12 @@ type LoggingOpts struct { // StartOpts specifies the application startup options type StartOpts struct { - File string `short:"f" long:"file" description:"SQL script file to execute during startup"` - Init bool `long:"init" description:"Initialize database schema to the latest version and exit. Can be used with --upgrade"` - Upgrade bool `long:"upgrade" description:"Upgrade database to the latest version"` - Debug bool `long:"debug" description:"Run in debug mode. Only asynchronous chains will be executed"` + File string `short:"f" long:"file" description:"SQL script or YAML chain definition file to execute during startup"` + Replace bool `long:"replace" description:"Replace existing chains when loading YAML files"` + Validate bool `long:"validate" description:"Only validate YAML file without importing chains"` + Init bool `long:"init" description:"Initialize database schema to the latest version and exit. Can be used with --upgrade"` + Upgrade bool `long:"upgrade" description:"Upgrade database to the latest version"` + Debug bool `long:"debug" description:"Run in debug mode. 
Only asynchronous chains will be executed"` } // ResourceOpts specifies the maximum resources available to application diff --git a/internal/pgengine/access.go b/internal/pgengine/access.go index cc11d54c..359b17db 100644 --- a/internal/pgengine/access.go +++ b/internal/pgengine/access.go @@ -40,7 +40,7 @@ func (pge *PgEngine) LogTaskExecution(ctx context.Context, task *ChainTask, retC _, err := pge.ConfigDb.Exec(ctx, `INSERT INTO timetable.execution_log ( chain_id, task_id, command, kind, last_run, finished, returncode, pid, output, client_name, txid, ignore_error) VALUES ($1, $2, $3, $4, clock_timestamp() - $5 :: interval, clock_timestamp(), $6, $7, NULLIF($8, ''), $9, $10, $11)`, - task.ChainID, task.TaskID, task.Script, task.Kind, + task.ChainID, task.TaskID, task.Command, task.Kind, fmt.Sprintf("%f seconds", float64(task.Duration)/1000000), retCode, pge.Getsid(), strings.TrimSpace(output), pge.ClientName, task.Vxid, task.IgnoreError) @@ -75,7 +75,7 @@ func (pge *PgEngine) RemoveChainRunStatus(ctx context.Context, chainID int) { // Select live chains with proper client_name value const sqlSelectLiveChains = `SELECT chain_id, chain_name, self_destruct, exclusive_execution, -COALESCE(max_instances, 16) as max_instances, COALESCE(timeout, 0) as timeout, on_error +COALESCE(max_instances, 16) as max_instances, COALESCE(timeout, 0) as timeout, COALESCE(on_error, '') as on_error FROM timetable.chain WHERE live AND (client_name = $1 or client_name IS NULL)` // SelectRebootChains returns a list of chains should be executed after reboot @@ -103,7 +103,7 @@ func (pge *PgEngine) SelectChains(ctx context.Context, dest *[]Chain) error { // SelectIntervalChains returns list of interval chains to be executed func (pge *PgEngine) SelectIntervalChains(ctx context.Context, dest *[]IntervalChain) error { const sqlSelectIntervalChains = `SELECT chain_id, chain_name, self_destruct, exclusive_execution, -COALESCE(max_instances, 16), COALESCE(timeout, 0), on_error, 
+COALESCE(max_instances, 16), COALESCE(timeout, 0), COALESCE(on_error, '') as on_error, EXTRACT(EPOCH FROM (substr(run_at, 7) :: interval)) :: int4 as interval_seconds, starts_with(run_at, '@after') as repeat_after FROM timetable.chain WHERE live AND (client_name = $1 or client_name IS NULL) AND substr(run_at, 1, 6) IN ('@every', '@after')` @@ -119,7 +119,7 @@ FROM timetable.chain WHERE live AND (client_name = $1 or client_name IS NULL) AN func (pge *PgEngine) SelectChain(ctx context.Context, dest *Chain, chainID int) error { // we accept not only live chains here because we want to run them in debug mode const sqlSelectSingleChain = `SELECT chain_id, chain_name, self_destruct, exclusive_execution, -COALESCE(timeout, 0) as timeout, COALESCE(max_instances, 16) as max_instances, on_error +COALESCE(timeout, 0) as timeout, COALESCE(max_instances, 16) as max_instances, COALESCE(on_error, '') as on_error FROM timetable.chain WHERE (client_name = $1 OR client_name IS NULL) AND chain_id = $2` rows, err := pge.ConfigDb.Query(ctx, sqlSelectSingleChain, pge.ClientName, chainID) if err != nil { diff --git a/internal/pgengine/bootstrap.go b/internal/pgengine/bootstrap.go index 1f362a54..8b84ff93 100644 --- a/internal/pgengine/bootstrap.go +++ b/internal/pgengine/bootstrap.go @@ -5,6 +5,7 @@ import ( "errors" "math/rand" "os" + "path/filepath" "strings" "time" @@ -108,7 +109,7 @@ func New(ctx context.Context, cmdOpts config.CmdOptions, logger log.LoggerHooker } pge.AddLogHook(ctx) //schema exists, we can log now if cmdOpts.Start.File != "" { - if err := pge.ExecuteCustomScripts(ctx, cmdOpts.Start.File); err != nil { + if err := pge.ExecuteFileScript(ctx, cmdOpts); err != nil { return nil, err } } @@ -217,6 +218,53 @@ func (pge *PgEngine) TryLockClientName(ctx context.Context, conn QueryRowIface) return nil } +// ExecuteFileScript handles both SQL and YAML files based on file extension +func (pge *PgEngine) ExecuteFileScript(ctx context.Context, cmdOpts config.CmdOptions) error { 
+ filePath := cmdOpts.Start.File + + // Determine file type by extension + fileExt := strings.ToLower(filepath.Ext(filePath)) + + switch fileExt { + case ".yaml", ".yml": + // Handle YAML chain definition files + if cmdOpts.Start.Validate { + // Only validate, don't import + _, err := ParseYamlFile(filePath) + if err != nil { + pge.l.WithError(err).Error("YAML validation failed") + return err + } + pge.l.WithField("file", filePath).Info("YAML file validation successful") + return nil + } + + // Import YAML chains + return pge.LoadYamlChains(ctx, filePath, cmdOpts.Start.Replace) + + case ".sql": + // Handle SQL script files (existing behavior) + return pge.ExecuteCustomScripts(ctx, filePath) + + default: + // Try to detect content type for files without extension + content, err := os.ReadFile(filePath) + if err != nil { + pge.l.WithError(err).Error("cannot read file") + return err + } + + // Check if it looks like YAML (starts with "chains:" or contains YAML markers) + contentStr := strings.TrimSpace(string(content)) + if strings.HasPrefix(contentStr, "chains:") { + pge.l.WithField("file", filePath).Info("Detected YAML content, processing as YAML") + return pge.LoadYamlChains(ctx, filePath, false) + } + pge.l.WithField("file", filePath).Info("Processing as SQL script") + return pge.ExecuteCustomScripts(ctx, filePath) + } +} + // ExecuteCustomScripts executes SQL scripts in files func (pge *PgEngine) ExecuteCustomScripts(ctx context.Context, filename ...string) error { for _, f := range filename { diff --git a/internal/pgengine/pgengine_test.go b/internal/pgengine/pgengine_test.go index 670af9f2..1e4873a9 100644 --- a/internal/pgengine/pgengine_test.go +++ b/internal/pgengine/pgengine_test.go @@ -7,7 +7,6 @@ import ( "testing" "time" - pgtype "github.com/jackc/pgx/v5/pgtype" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -162,7 +161,7 @@ func TestGetRemoteDBTransaction(t *testing.T) { require.NoError(t, err, "remoteDB should be 
initialized") require.NotNil(t, remoteDb, "remoteDB should be initialized") - assert.NoError(t, pge.SetRole(ctx, remoteDb, pgtype.Text{String: "scheduler", Valid: true}), + assert.NoError(t, pge.SetRole(ctx, remoteDb, "scheduler"), "Set Role failed") assert.NotPanics(t, func() { pge.ResetRole(ctx, remoteDb) }, "Reset Role failed") pge.FinalizeDBConnection(ctx, remoteDb) diff --git a/internal/pgengine/transaction.go b/internal/pgengine/transaction.go index 5f7a0ff0..14814986 100644 --- a/internal/pgengine/transaction.go +++ b/internal/pgengine/transaction.go @@ -11,7 +11,6 @@ import ( "github.com/cybertec-postgresql/pg_timetable/internal/log" pgx "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgconn" - "github.com/jackc/pgx/v5/pgtype" ) // StartTransaction returns transaction object, virtual transaction id and error @@ -77,11 +76,11 @@ func (pge *PgEngine) ExecLocalSQLTask(ctx context.Context, tx pgx.Tx, task *Chai pge.MustSavepoint(ctx, tx, task.TaskID) } pge.SetCurrentTaskContext(ctx, tx, task.ChainID, task.TaskID) - out, err = pge.ExecuteSQLCommand(ctx, tx, task.Script, paramValues) + out, err = pge.ExecuteSQLCommand(ctx, tx, task.Command, paramValues) if err != nil && task.IgnoreError { pge.MustRollbackToSavepoint(ctx, tx, task.TaskID) } - if task.RunAs.Valid { + if task.RunAs > "" { pge.ResetRole(ctx, tx) } return @@ -98,14 +97,14 @@ func (pge *PgEngine) ExecStandaloneTask(ctx context.Context, connf func() (PgxCo return "", err } pge.SetCurrentTaskContext(ctx, conn, task.ChainID, task.TaskID) - return pge.ExecuteSQLCommand(ctx, conn, task.Script, paramValues) + return pge.ExecuteSQLCommand(ctx, conn, task.Command, paramValues) } // ExecRemoteSQLTask executes task against remote connection func (pge *PgEngine) ExecRemoteSQLTask(ctx context.Context, task *ChainTask, paramValues []string) (string, error) { log.GetLogger(ctx).Info("Switching to remote task mode") return pge.ExecStandaloneTask(ctx, - func() (PgxConnIface, error) { return 
pge.GetRemoteDBConnection(ctx, task.ConnectString.String) }, + func() (PgxConnIface, error) { return pge.GetRemoteDBConnection(ctx, task.ConnectString) }, task, paramValues) } @@ -171,12 +170,12 @@ func (pge *PgEngine) FinalizeDBConnection(ctx context.Context, remoteDb PgxConnI } // SetRole - set the current user identifier of the current session -func (pge *PgEngine) SetRole(ctx context.Context, executor executor, runUID pgtype.Text) error { - if !runUID.Valid || strings.TrimSpace(runUID.String) == "" { +func (pge *PgEngine) SetRole(ctx context.Context, executor executor, runUID string) error { + if strings.TrimSpace(runUID) == "" { return nil } - log.GetLogger(ctx).Info("Setting role to ", runUID.String) - _, err := executor.Exec(ctx, fmt.Sprintf("SET ROLE %v", runUID.String)) + log.GetLogger(ctx).Info("Setting role to ", runUID) + _, err := executor.Exec(ctx, fmt.Sprintf("SET ROLE %v", runUID)) return err } diff --git a/internal/pgengine/transaction_test.go b/internal/pgengine/transaction_test.go index e12f81b3..9ac88e6e 100644 --- a/internal/pgengine/transaction_test.go +++ b/internal/pgengine/transaction_test.go @@ -8,7 +8,6 @@ import ( "github.com/cybertec-postgresql/pg_timetable/internal/pgengine" "github.com/jackc/pgx/v5/pgconn" - "github.com/jackc/pgx/v5/pgtype" "github.com/pashagolub/pgxmock/v4" "github.com/stretchr/testify/assert" ) @@ -94,7 +93,7 @@ func TestExecuteSQLTask(t *testing.T) { }) t.Run("Check remote SQL task", func(t *testing.T) { - task := pgengine.ChainTask{ConnectString: pgtype.Text{String: "foo", Valid: true}} + task := pgengine.ChainTask{ConnectString: "foo"} _, err := pge.ExecuteSQLTask(ctx, nil, &task, []string{}) assert.ErrorContains(t, err, "cannot parse") }) @@ -123,8 +122,8 @@ func TestExecLocalSQLTask(t *testing.T) { task := pgengine.ChainTask{ TaskID: 42, IgnoreError: true, - Script: "FOO", - RunAs: pgtype.Text{String: "Bob", Valid: true}, + Command: "FOO", + RunAs: "Bob", } _, err := pge.ExecLocalSQLTask(ctx, mockPool, &task, 
[]string{}) assert.Error(t, err) @@ -148,8 +147,8 @@ func TestExecStandaloneTask(t *testing.T) { task := pgengine.ChainTask{ TaskID: 42, IgnoreError: true, - Script: "FOO", - RunAs: pgtype.Text{String: "Bob", Valid: true}, + Command: "FOO", + RunAs: "Bob", } cf := func() (pgengine.PgxConnIface, error) { return mockPool.AsConn(), nil } @@ -227,8 +226,8 @@ func TestSetRole(t *testing.T) { mockPool.ExpectExec("SET ROLE").WillReturnError(errors.New("error")) tx, err := mockPool.Begin(ctx) assert.NoError(t, err) - assert.Error(t, pge.SetRole(ctx, tx, pgtype.Text{String: "foo", Valid: true})) - assert.NoError(t, pge.SetRole(ctx, tx, pgtype.Text{String: "", Valid: false}), "Should ignore empty run_as") + assert.Error(t, pge.SetRole(ctx, tx, "foo")) + assert.NoError(t, pge.SetRole(ctx, tx, ""), "Should ignore empty run_as") mockPool.ExpectBegin() mockPool.ExpectExec("RESET ROLE").WillReturnError(errors.New("error")) diff --git a/internal/pgengine/types.go b/internal/pgengine/types.go index a7da2d70..daceea17 100644 --- a/internal/pgengine/types.go +++ b/internal/pgengine/types.go @@ -6,7 +6,6 @@ import ( "time" pgconn "github.com/jackc/pgx/v5/pgconn" - "github.com/jackc/pgx/v5/pgtype" ) type executor interface { @@ -15,13 +14,13 @@ type executor interface { // Chain structure used to represent tasks chains type Chain struct { - ChainID int `db:"chain_id"` - ChainName string `db:"chain_name"` - SelfDestruct bool `db:"self_destruct"` - ExclusiveExecution bool `db:"exclusive_execution"` - MaxInstances int `db:"max_instances"` - Timeout int `db:"timeout"` - OnErrorSQL pgtype.Text `db:"on_error"` + ChainID int `db:"chain_id" yaml:"-"` + ChainName string `db:"chain_name" yaml:"name"` + SelfDestruct bool `db:"self_destruct" yaml:"self_destruct,omitempty"` + ExclusiveExecution bool `db:"exclusive_execution" yaml:"exclusive,omitempty"` + MaxInstances int `db:"max_instances" yaml:"max_instances,omitempty"` + Timeout int `db:"timeout" yaml:"timeout,omitempty"` + OnError string 
`db:"on_error" yaml:"on_error,omitempty"` } // IntervalChain structure used to represent repeated chains. @@ -42,20 +41,20 @@ func (ichain IntervalChain) IsListed(ichains []IntervalChain) bool { // ChainTask structure describes each chain task type ChainTask struct { - ChainID int `db:"-"` - TaskID int `db:"task_id"` - Script string `db:"command"` - Kind string `db:"kind"` - RunAs pgtype.Text `db:"run_as"` - IgnoreError bool `db:"ignore_error"` - Autonomous bool `db:"autonomous"` - ConnectString pgtype.Text `db:"database_connection"` - Timeout int `db:"timeout"` // in milliseconds - StartedAt time.Time `db:"-"` - Duration int64 `db:"-"` // in microseconds - Vxid int64 `db:"-"` + ChainID int `db:"-" yaml:"-"` + TaskID int `db:"task_id" yaml:"-"` + Command string `db:"command" yaml:"command"` + Kind string `db:"kind" yaml:"kind,omitempty"` + RunAs string `db:"run_as" yaml:"run_as,omitempty"` + IgnoreError bool `db:"ignore_error" yaml:"ignore_error,omitempty"` + Autonomous bool `db:"autonomous" yaml:"autonomous,omitempty"` + ConnectString string `db:"database_connection" yaml:"connect_string,omitempty"` + Timeout int `db:"timeout" yaml:"timeout,omitempty"` // in milliseconds + StartedAt time.Time `db:"-" yaml:"-"` + Duration int64 `db:"-" yaml:"-"` // in microseconds + Vxid int64 `db:"-" yaml:"-"` } func (task *ChainTask) IsRemote() bool { - return task.ConnectString.Valid && strings.TrimSpace(task.ConnectString.String) != "" + return strings.TrimSpace(task.ConnectString) != "" } diff --git a/internal/pgengine/yaml.go b/internal/pgengine/yaml.go new file mode 100644 index 00000000..57f9aee8 --- /dev/null +++ b/internal/pgengine/yaml.go @@ -0,0 +1,311 @@ +package pgengine + +import ( + "context" + "fmt" + "os" + "path/filepath" + "strings" + + "gopkg.in/yaml.v3" +) + +// YamlChain represents a chain with tasks for YAML processing +type YamlChain struct { + Chain + ClientName string `db:"client_name" yaml:"client_name,omitempty"` + Schedule string `db:"run_at" 
yaml:"schedule,omitempty"` + Live bool `db:"live" yaml:"live,omitempty"` + Tasks []YamlTask `yaml:"tasks"` +} + +// YamlTask extends the basic task structure with Parameters field +type YamlTask struct { + ChainTask + TaskName string `db:"task_name" yaml:"name,omitempty"` + Parameters []interface{} `yaml:"parameters,omitempty"` +} + +// YamlConfig represents the root YAML configuration +type YamlConfig struct { + Chains []YamlChain `yaml:"chains"` +} + +// LoadYamlChains loads chains from a YAML file and imports them +func (pge *PgEngine) LoadYamlChains(ctx context.Context, filePath string, replace bool) error { + // Parse YAML file + yamlConfig, err := ParseYamlFile(filePath) + if err != nil { + return fmt.Errorf("failed to parse YAML file: %w", err) + } + + // Import chains + for _, yamlChain := range yamlConfig.Chains { + // Delete existing chain if replace mode + if replace { + _, _ = pge.ConfigDb.Exec(ctx, "SELECT timetable.delete_job($1)", yamlChain.ChainName) + } + + // Check if chain exists + var exists bool + err := pge.ConfigDb.QueryRow(ctx, + "SELECT EXISTS(SELECT 1 FROM timetable.chain WHERE chain_name = $1)", + yamlChain.ChainName).Scan(&exists) + if err != nil { + return fmt.Errorf("failed to check if chain exists: %w", err) + } + if exists && !replace { + return fmt.Errorf("chain '%s' already exists (use --replace flag to overwrite)", yamlChain.ChainName) + } + + // Use existing add_job function for single-task chains + if len(yamlChain.Tasks) == 1 { + task := yamlChain.Tasks[0] + params, _ := task.ToSQLParameters() + var paramsValue interface{} + if params == "" { + paramsValue = nil + } else { + paramsValue = params + } + + _, err := pge.ConfigDb.Exec(ctx, ` + SELECT timetable.add_job($1, $2, $3, $4::jsonb, $5::timetable.command_kind, $6, $7, $8, $9, $10, $11, $12)`, + yamlChain.ChainName, + yamlChain.Schedule, + task.Command, + paramsValue, + task.Kind, + nullString(yamlChain.ClientName), + yamlChain.MaxInstances, + yamlChain.Live, + 
yamlChain.SelfDestruct, + task.IgnoreError, + yamlChain.ExclusiveExecution, + nullString(yamlChain.OnError)) + if err != nil { + return fmt.Errorf("failed to create chain %s: %w", yamlChain.ChainName, err) + } + } else { + // Multi-task chain - use direct SQL + chainID, err := pge.createChainFromYaml(ctx, &yamlChain) + if err != nil { + return fmt.Errorf("failed to create multi-task chain %s: %w", yamlChain.ChainName, err) + } + pge.l.WithField("chain", yamlChain.ChainName).WithField("chain_id", chainID).Info("Created multi-task chain") + } + } + + pge.l.WithField("chains", len(yamlConfig.Chains)).WithField("file", filePath).Info("Successfully imported YAML chains") + return nil +} + +// createChainFromYaml creates a multi-task chain using direct SQL inserts +func (pge *PgEngine) createChainFromYaml(ctx context.Context, yamlChain *YamlChain) (int64, error) { + // Insert chain + var chainID int64 + err := pge.ConfigDb.QueryRow(ctx, ` + INSERT INTO timetable.chain ( + chain_name, run_at, max_instances, timeout, live, + self_destruct, exclusive_execution, client_name, on_error + ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9) + RETURNING chain_id`, + yamlChain.ChainName, + yamlChain.Schedule, + yamlChain.MaxInstances, + yamlChain.Timeout, + yamlChain.Live, + yamlChain.SelfDestruct, + yamlChain.ExclusiveExecution, + nullString(yamlChain.ClientName), + nullString(yamlChain.OnError)).Scan(&chainID) + if err != nil { + return 0, fmt.Errorf("failed to insert chain: %w", err) + } + + // Insert tasks + for i, task := range yamlChain.Tasks { + taskOrder := float64((i + 1) * 10) + + var taskID int64 + err := pge.ConfigDb.QueryRow(ctx, ` + INSERT INTO timetable.task ( + chain_id, task_order, task_name, kind, command, + run_as, database_connection, ignore_error, autonomous, timeout + ) VALUES ($1, $2, $3, $4::timetable.command_kind, $5, $6, $7, $8, $9, $10) + RETURNING task_id`, + chainID, + taskOrder, + nullString(task.TaskName), + task.Kind, + task.Command, + 
nullString(task.RunAs), + nullString(task.ConnectString), + task.IgnoreError, + task.Autonomous, + task.Timeout).Scan(&taskID) + if err != nil { + return 0, fmt.Errorf("failed to insert task %d: %w", i+1, err) + } + + // Insert parameters if any + if len(task.Parameters) > 0 { + params, err := task.ToSQLParameters() + if err != nil { + return 0, fmt.Errorf("failed to convert parameters: %w", err) + } + _, err = pge.ConfigDb.Exec(ctx, + "INSERT INTO timetable.parameter (task_id, order_id, value) VALUES ($1, 1, $2::jsonb)", + taskID, params) + if err != nil { + return 0, fmt.Errorf("failed to insert parameters: %w", err) + } + } + } + + return chainID, nil +} + +// nullString returns nil for empty strings, otherwise returns the string +func nullString(s string) interface{} { + if s == "" { + return nil + } + return s +} + +// ValidateChain validates a YAML chain configuration +func (c *YamlChain) ValidateChain() error { + if c.ChainName == "" { + return fmt.Errorf("chain name is required") + } + + if c.Schedule == "" { + return fmt.Errorf("chain schedule is required") + } + + // Validate cron format + switch c.Schedule { + case "", "@reboot", "@after", "@every": + // Valid special schedules + default: + fields := strings.Fields(c.Schedule) + if len(fields) != 5 { + return fmt.Errorf("invalid cron format: %s (expected 5 fields)", c.Schedule) + } + } + + if len(c.Tasks) == 0 { + return fmt.Errorf("chain must have at least one task") + } + + // Validate each task + for i, task := range c.Tasks { + if err := task.ValidateTask(); err != nil { + return fmt.Errorf("task %d: %w", i+1, err) + } + } + + return nil +} + +// ValidateTask validates a YAML task configuration +func (t *YamlTask) ValidateTask() error { + if t.Command == "" { + return fmt.Errorf("task command is required") + } + + // Validate kind + switch strings.ToUpper(t.Kind) { + case "", "SQL", "PROGRAM", "BUILTIN": + // Valid kinds + default: + return fmt.Errorf("invalid task kind: %s (must be SQL, PROGRAM, or 
BUILTIN)", t.Kind) + } + + // Validate timeout is non-negative + if t.Timeout < 0 { + return fmt.Errorf("task timeout must be non-negative") + } + + return nil +} + +// SetDefaults sets default values for optional fields +func (c *YamlChain) SetDefaults() { + // Chain defaults + if c.Schedule == "" { + c.Schedule = "* * * * *" // Default to every minute + } + + // Task defaults + for i := range c.Tasks { + task := &c.Tasks[i] + if task.Kind == "" { + task.Kind = "SQL" + } + } +} + +// ParseYamlFile parses a YAML file and returns the configuration +func ParseYamlFile(filePath string) (*YamlConfig, error) { + // Check if file exists + if _, err := os.Stat(filePath); os.IsNotExist(err) { + return nil, fmt.Errorf("file not found: %s", filePath) + } + + // Check file extension + ext := strings.ToLower(filepath.Ext(filePath)) + if ext != ".yaml" && ext != ".yml" { + return nil, fmt.Errorf("file must have .yaml or .yml extension: %s", filePath) + } + + // Read file + data, err := os.ReadFile(filePath) + if err != nil { + return nil, fmt.Errorf("failed to read file: %w", err) + } + + // Parse YAML + var config YamlConfig + if err := yaml.Unmarshal(data, &config); err != nil { + return nil, fmt.Errorf("failed to parse YAML: %w", err) + } + + // Set defaults and validate + for i := range config.Chains { + chain := &config.Chains[i] + chain.SetDefaults() + if err := chain.ValidateChain(); err != nil { + return nil, fmt.Errorf("chain %d (%s): %w", i+1, chain.ChainName, err) + } + } + + return &config, nil +} + +// ToSQLParameters converts YAML parameters to SQL-compatible format +func (t *YamlTask) ToSQLParameters() (string, error) { + if len(t.Parameters) == 0 { + return "", nil + } + + // Convert to JSON array format for PostgreSQL + params := make([]string, len(t.Parameters)) + for i, param := range t.Parameters { + switch v := param.(type) { + case string: + params[i] = fmt.Sprintf(`"%s"`, strings.ReplaceAll(v, `"`, `\"`)) + case int, int32, int64: + params[i] = 
fmt.Sprintf("%v", v) + case float32, float64: + params[i] = fmt.Sprintf("%v", v) + case bool: + params[i] = fmt.Sprintf("%t", v) + default: + params[i] = fmt.Sprintf(`"%v"`, v) + } + } + + return fmt.Sprintf("[%s]", strings.Join(params, ", ")), nil +} diff --git a/internal/pgengine/yaml_test.go b/internal/pgengine/yaml_test.go new file mode 100644 index 00000000..76f9c443 --- /dev/null +++ b/internal/pgengine/yaml_test.go @@ -0,0 +1,194 @@ +package pgengine_test + +import ( + "context" + "os" + "testing" + + "github.com/cybertec-postgresql/pg_timetable/internal/pgengine" + "github.com/cybertec-postgresql/pg_timetable/internal/testutils" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// Helper function to create temporary YAML file +func createTempYamlFile(t *testing.T, content string) string { + tmpfile, err := os.CreateTemp("", "test-*.yaml") + require.NoError(t, err) + + _, err = tmpfile.Write([]byte(content)) + require.NoError(t, err) + + err = tmpfile.Close() + require.NoError(t, err) + + return tmpfile.Name() +} + +// Helper function to remove temporary file +func removeTempFile(t *testing.T, filePath string) { + err := os.Remove(filePath) + require.NoError(t, err) +} + +func TestLoadYamlChainsIntegration(t *testing.T) { + container, cleanup := testutils.SetupPostgresContainer(t) + defer cleanup() + + ctx := context.Background() + pge := container.Engine + + t.Run("Single task chain", func(t *testing.T) { + // Create a simple YAML chain config + yamlContent := `chains: + - name: test-single-task + schedule: "0 0 * * *" + tasks: + - command: SELECT 1 + kind: SQL` + + // Create temporary YAML file + tempFile := createTempYamlFile(t, yamlContent) + defer removeTempFile(t, tempFile) + + // Load the chain + err := pge.LoadYamlChains(ctx, tempFile, false) + require.NoError(t, err) + + // Verify the chain was created + var count int + err = pge.ConfigDb.QueryRow(ctx, + "SELECT COUNT(*) FROM timetable.chain WHERE chain_name = $1", 
+ "test-single-task").Scan(&count) + require.NoError(t, err) + assert.Equal(t, 1, count) + }) + + t.Run("Replace existing chain", func(t *testing.T) { + chainName := "test-replace-chain" + + // First, create a chain + yamlContent1 := `chains: + - name: test-replace-chain + schedule: "0 0 * * *" + tasks: + - command: SELECT 1 + kind: SQL` + + tempFile1 := createTempYamlFile(t, yamlContent1) + defer removeTempFile(t, tempFile1) + + err := pge.LoadYamlChains(ctx, tempFile1, false) + require.NoError(t, err) + + // Now replace it with a different chain + yamlContent2 := `chains: + - name: test-replace-chain + schedule: "0 1 * * *" + tasks: + - command: SELECT 2 + kind: SQL` + + tempFile2 := createTempYamlFile(t, yamlContent2) + defer removeTempFile(t, tempFile2) + + // Should succeed with replace=true + err = pge.LoadYamlChains(ctx, tempFile2, true) + require.NoError(t, err) + + // Verify the schedule was updated + var schedule string + err = pge.ConfigDb.QueryRow(ctx, + "SELECT run_at FROM timetable.chain WHERE chain_name = $1", + chainName).Scan(&schedule) + require.NoError(t, err) + assert.Equal(t, "0 1 * * *", schedule) + }) +} + +func TestYamlParameterHandling(t *testing.T) { + // Test parsing and validation of different parameter formats + yamlContent := `chains: + - name: "test-parameters" + schedule: "0 0 * * *" + tasks: + - name: "sql-test" + kind: "SQL" + command: "SELECT $1, $2, $3" + parameters: + - ["value1", 42, true] + - ["value2", 99, false] + + - name: "program-test" + kind: "PROGRAM" + command: "echo" + parameters: + - ["-n", "hello world"] + - ["goodbye"] + + - name: "sleep-test" + kind: "BUILTIN" + command: "Sleep" + parameters: + - 5 + - 10 + + - name: "log-test" + kind: "BUILTIN" + command: "Log" + parameters: + - "warning message" + - {"level": "WARNING", "message": "test"} +` + + // Create temporary file with content + tmpfile, err := os.CreateTemp("", "test-*.yaml") + require.NoError(t, err) + defer os.Remove(tmpfile.Name()) + + _, err = 
tmpfile.Write([]byte(yamlContent)) + require.NoError(t, err) + err = tmpfile.Close() + require.NoError(t, err) + + // Parse the YAML + yamlConfig, err := pgengine.ParseYamlFile(tmpfile.Name()) + require.NoError(t, err) + + // Check parsed content + require.Equal(t, 1, len(yamlConfig.Chains)) + chain := yamlConfig.Chains[0] + require.Equal(t, "test-parameters", chain.ChainName) + require.Equal(t, 4, len(chain.Tasks)) + + // Check SQL task parameters + sqlTask := chain.Tasks[0] + require.Equal(t, "SQL", sqlTask.Kind) + require.Equal(t, 2, len(sqlTask.Parameters)) + sqlParam1, ok := sqlTask.Parameters[0].([]interface{}) + require.True(t, ok, "SQL parameter should be an array") + assert.Equal(t, 3, len(sqlParam1)) + + // Check PROGRAM task parameters + programTask := chain.Tasks[1] + require.Equal(t, "PROGRAM", programTask.Kind) + require.Equal(t, 2, len(programTask.Parameters)) + + // Check BUILTIN Sleep task parameters + sleepTask := chain.Tasks[2] + require.Equal(t, "BUILTIN", sleepTask.Kind) + require.Equal(t, "Sleep", sleepTask.Command) + require.Equal(t, 2, len(sleepTask.Parameters)) + sleepParam1, ok := sleepTask.Parameters[0].(int) + require.True(t, ok, "Sleep parameter should be an integer") + assert.Equal(t, 5, sleepParam1) + + // Check BUILTIN Log task parameters + logTask := chain.Tasks[3] + require.Equal(t, "BUILTIN", logTask.Kind) + require.Equal(t, "Log", logTask.Command) + require.Equal(t, 2, len(logTask.Parameters)) + logParam1, ok := logTask.Parameters[0].(string) + require.True(t, ok, "Log parameter can be a string") + assert.Equal(t, "warning message", logParam1) +} diff --git a/internal/scheduler/chain.go b/internal/scheduler/chain.go index da82c45d..fa52951a 100644 --- a/internal/scheduler/chain.go +++ b/internal/scheduler/chain.go @@ -176,12 +176,12 @@ func getTimeoutContext(ctx context.Context, globalTimeout int, customTimeout int } func (sch *Scheduler) executeOnErrorHandler(ctx context.Context, chain Chain) { - if ctx.Err() != nil || 
!chain.OnErrorSQL.Valid { + if ctx.Err() != nil || chain.OnError == "" { return } l := sch.l.WithField("chain", chain.ChainID) l.Info("Starting error handling") - if _, err := sch.pgengine.ConfigDb.Exec(ctx, chain.OnErrorSQL.String); err != nil { + if _, err := sch.pgengine.ConfigDb.Exec(ctx, chain.OnError); err != nil { l.Info("Error handler failed") return } @@ -277,9 +277,9 @@ func (sch *Scheduler) executeTask(ctx context.Context, tx pgx.Tx, task *pgengine l.Info("Program task execution skipped") return -2 } - retCode, out, err = sch.ExecuteProgramCommand(ctx, task.Script, paramValues) + retCode, out, err = sch.ExecuteProgramCommand(ctx, task.Command, paramValues) case "BUILTIN": - out, err = sch.executeBuiltinTask(ctx, task.Script, paramValues) + out, err = sch.executeBuiltinTask(ctx, task.Command, paramValues) } task.Duration = time.Since(task.StartedAt).Microseconds() diff --git a/internal/scheduler/chain_test.go b/internal/scheduler/chain_test.go index 8524b481..71b7272d 100644 --- a/internal/scheduler/chain_test.go +++ b/internal/scheduler/chain_test.go @@ -12,7 +12,6 @@ import ( "github.com/cybertec-postgresql/pg_timetable/internal/pgengine" "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgconn" - "github.com/jackc/pgx/v5/pgtype" "github.com/pashagolub/pgxmock/v4" "github.com/stretchr/testify/assert" ) @@ -118,7 +117,7 @@ func TestExecuteChainElement(t *testing.T) { } func TestExecuteOnErrorHandler(t *testing.T) { - c := Chain{ChainID: 42, OnErrorSQL: pgtype.Text{String: "FOO", Valid: true}} + c := Chain{ChainID: 42, OnError: "FOO"} mock, err := pgxmock.NewPool() assert.NoError(t, err) pge := pgengine.NewDB(mock, "-c", "scheduler_unit_test", "--password=somestrong") diff --git a/mkdocs.yml b/mkdocs.yml index c149f4c4..f5407c39 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -66,8 +66,10 @@ nav: - Basic Jobs: basic_jobs.md - Migration: migration.md - Samples: samples.md + - YAML Chains: yaml-usage-guide.md - Reference: - API: api.md + - YAML Schema: 
yaml-format.md - Developer: - License: license.md - Documentation: ../devel/godoc/index.html diff --git a/samples/yaml/Backup.yaml b/samples/yaml/Backup.yaml new file mode 100644 index 00000000..b34b4239 --- /dev/null +++ b/samples/yaml/Backup.yaml @@ -0,0 +1,65 @@ +# Backup job example with program tasks + +chains: + - name: "Database Backup Job" + schedule: "0 3 * * 0" # Sunday at 3 AM + live: true + max_instances: 1 + timeout: 7200000 # 2 hour timeout + client_name: "backup-worker" + + tasks: + - name: "Create backup directory" + kind: "PROGRAM" + command: "mkdir" + parameters: + - "-p" + - "/backups/$(date +%Y-%m-%d)" + ignore_error: true + + - name: "Database dump" + kind: "PROGRAM" + command: "pg_dump" + parameters: + - "-h" + - "localhost" + - "-U" + - "postgres" + - "-d" + - "mydb" + - "-f" + - "/backups/$(date +%Y-%m-%d)/mydb_backup.sql" + - "--verbose" + timeout: 5400000 # 90 minutes for dump + + - name: "Compress backup" + kind: "PROGRAM" + command: "gzip" + parameters: + - "/backups/$(date +%Y-%m-%d)/mydb_backup.sql" + timeout: 600000 # 10 minutes for compression + + - name: "Cleanup old backups" + kind: "PROGRAM" + command: "find" + parameters: + - "/backups" + - "-type" + - "f" + - "-name" + - "*.sql.gz" + - "-mtime" + - "+7" # older than 7 days + - "-delete" + ignore_error: true + + - name: "Log backup completion" + kind: "SQL" + command: | + INSERT INTO backup_log (backup_date, backup_type, status, file_path) + VALUES ( + CURRENT_DATE, + 'full_database', + 'completed', + '/backups/' || TO_CHAR(CURRENT_DATE, 'YYYY-MM-DD') || '/mydb_backup.sql.gz' + ) \ No newline at end of file diff --git a/samples/yaml/Basic.yaml b/samples/yaml/Basic.yaml new file mode 100644 index 00000000..3b78d66f --- /dev/null +++ b/samples/yaml/Basic.yaml @@ -0,0 +1,16 @@ +# Basic YAML chain example +# Equivalent to samples/Basic.sql + +chains: + - name: "notify every minute" + schedule: "* * * * *" # Every minute + live: true + max_instances: 1 + self_destruct: false + + 
tasks: + - name: "Send notification" + kind: "SQL" + command: "SELECT pg_notify($1, $2)" + parameters: ["TT_CHANNEL", "Ahoj from YAML base task"] + ignore_error: true \ No newline at end of file diff --git a/samples/yaml/Chain.yaml b/samples/yaml/Chain.yaml new file mode 100644 index 00000000..8fbfcd89 --- /dev/null +++ b/samples/yaml/Chain.yaml @@ -0,0 +1,40 @@ +# Multi-task chain example +# Equivalent to samples/Chain.sql but simplified + +chains: + - name: "chain operation" + schedule: "* * * * *" # Every minute + live: true + max_instances: 1 + self_destruct: false + exclusive: false + + tasks: + - name: "Initialize chain log table" + kind: "SQL" + command: | + CREATE TABLE IF NOT EXISTS timetable.chain_log ( + chain_log BIGSERIAL, + EVENT TEXT, + time TIMESTAMPTZ, + PRIMARY KEY (chain_log) + ) + ignore_error: true + + - name: "Log chain start" + kind: "SQL" + command: "INSERT INTO timetable.chain_log (EVENT, time) VALUES ($1, CURRENT_TIMESTAMP)" + parameters: ["Chain started"] + ignore_error: true + + - name: "Log chain processing" + kind: "SQL" + command: "INSERT INTO timetable.chain_log (EVENT, time) VALUES ($1, CURRENT_TIMESTAMP)" + parameters: ["Chain processing"] + ignore_error: true + + - name: "Log chain completion" + kind: "SQL" + command: "INSERT INTO timetable.chain_log (EVENT, time) VALUES ($1, CURRENT_TIMESTAMP)" + parameters: ["Chain completed"] + ignore_error: true \ No newline at end of file diff --git a/samples/yaml/CronStyle.yaml b/samples/yaml/CronStyle.yaml new file mode 100644 index 00000000..298323cf --- /dev/null +++ b/samples/yaml/CronStyle.yaml @@ -0,0 +1,23 @@ +# Cron-style scheduling example +# Equivalent to samples/CronStyle.sql + +chains: + - name: "cron_Job run after 40th minutes after 2 hour on 27th of every month" + schedule: "40 */2 27 * *" # 40th minute, every 2 hours, on 27th of month + live: true + + tasks: + - name: "Create dummy log table" + kind: "SQL" + command: | + CREATE TABLE IF NOT EXISTS timetable.dummy_log ( + 
log_ID BIGSERIAL, + event_name TEXT, + timestmp TIMESTAMPTZ DEFAULT TRANSACTION_TIMESTAMP(), + PRIMARY KEY (log_ID) + ) + ignore_error: true + + - name: "Insert cron test event" + kind: "SQL" + command: "INSERT INTO timetable.dummy_log (event_name) VALUES ('Cron test')" \ No newline at end of file diff --git a/samples/yaml/ETLPipeline.yaml b/samples/yaml/ETLPipeline.yaml new file mode 100644 index 00000000..8f537dae --- /dev/null +++ b/samples/yaml/ETLPipeline.yaml @@ -0,0 +1,69 @@ +# ETL Pipeline example - comprehensive multi-step chain + +chains: + - name: "ETL Pipeline" + schedule: "0 2 * * *" # Daily at 2 AM + live: true + max_instances: 1 + timeout: 3600000 # 1 hour timeout + + tasks: + - name: "Extract sales data" + kind: "SQL" + command: | + CREATE TEMP TABLE temp_sales_extract AS + SELECT * FROM sales_raw + WHERE created_date >= CURRENT_DATE - INTERVAL '1 day' + + - name: "Validate extracted data" + kind: "SQL" + command: | + DO $$ + DECLARE + row_count INTEGER; + BEGIN + SELECT COUNT(*) INTO row_count FROM temp_sales_extract; + IF row_count = 0 THEN + RAISE EXCEPTION 'No data extracted for processing'; + END IF; + RAISE NOTICE 'Extracted % rows for processing', row_count; + END $$ + + - name: "Transform data" + kind: "SQL" + command: | + CREATE TEMP TABLE temp_sales_transformed AS + SELECT + sales_id, + customer_id, + UPPER(TRIM(customer_name)) as customer_name, + product_id, + ROUND(amount, 2) as amount, + created_date + FROM temp_sales_extract + WHERE amount > 0 + autonomous: true + + - name: "Load to data warehouse" + kind: "SQL" + command: | + INSERT INTO sales_warehouse + (sales_id, customer_id, customer_name, product_id, amount, processed_date) + SELECT + sales_id, customer_id, customer_name, product_id, amount, CURRENT_DATE + FROM temp_sales_transformed + ON CONFLICT (sales_id) DO UPDATE SET + amount = EXCLUDED.amount, + processed_date = EXCLUDED.processed_date + + - name: "Update processing log" + kind: "SQL" + command: | + INSERT INTO etl_log 
(process_name, process_date, rows_processed, status) + SELECT + 'ETL Pipeline', + CURRENT_DATE, + COUNT(*), + 'SUCCESS' + FROM temp_sales_transformed + ignore_error: false \ No newline at end of file diff --git a/samples/yaml/MultipleChains.yaml b/samples/yaml/MultipleChains.yaml new file mode 100644 index 00000000..f9d9f61d --- /dev/null +++ b/samples/yaml/MultipleChains.yaml @@ -0,0 +1,75 @@ +# Multiple chains in a single file example + +chains: + # Daily monitoring chain + - name: "Daily Health Check" + schedule: "0 8 * * *" # Daily at 8 AM + live: true + max_instances: 1 + + tasks: + - name: "Check database connections" + kind: "SQL" + command: | + SELECT + 'Database connection check: ' || + CASE WHEN COUNT(*) > 0 THEN 'OK' ELSE 'FAIL' END as status + FROM pg_stat_activity + WHERE state = 'active' + + - name: "Check disk space" + kind: "PROGRAM" + command: "df" + parameters: ["-h", "/var/lib/postgresql"] + ignore_error: true + + - name: "Log health check" + kind: "SQL" + command: | + INSERT INTO health_check_log (check_date, status) + VALUES (CURRENT_DATE, 'completed') + + # Hourly cleanup chain + - name: "Hourly Cleanup" + schedule: "0 * * * *" # Every hour + live: true + max_instances: 1 + timeout: 300000 # 5 minutes + + tasks: + - name: "Clean temporary tables" + kind: "SQL" + command: | + DROP TABLE IF EXISTS temp_sales_extract; + DROP TABLE IF EXISTS temp_sales_transformed; + ignore_error: true + + - name: "Vacuum analyze stats tables" + kind: "SQL" + command: "VACUUM ANALYZE pg_stat_user_tables" + autonomous: true + ignore_error: true + + # Weekly report chain + - name: "Weekly Sales Report" + schedule: "0 9 * * 1" # Monday at 9 AM + live: true + max_instances: 1 + timeout: 1800000 # 30 minutes + + tasks: + - name: "Generate weekly sales summary" + kind: "SQL" + command: | + INSERT INTO weekly_reports (report_date, total_sales, total_orders) + SELECT + DATE_TRUNC('week', CURRENT_DATE), + SUM(amount), + COUNT(*) + FROM sales_warehouse + WHERE created_date 
>= DATE_TRUNC('week', CURRENT_DATE) - INTERVAL '1 week' + AND created_date < DATE_TRUNC('week', CURRENT_DATE) + + - name: "Send report notification" + kind: "SQL" + command: "SELECT pg_notify('reports', 'Weekly sales report generated')" \ No newline at end of file diff --git a/samples/yaml/Parameters.yaml b/samples/yaml/Parameters.yaml new file mode 100644 index 00000000..c8a340df --- /dev/null +++ b/samples/yaml/Parameters.yaml @@ -0,0 +1,51 @@ +chains: + - name: "parameter-examples" + schedule: "0 0 * * *" + live: true + + tasks: + # SQL task with parameters + - name: "sql-example" + kind: "SQL" + command: "SELECT $1, $2, $3, $4" + parameters: + - ["one", 2, 3.14, false] # First execution with these parameters + - ["two", 4, 6.28, true] # Second execution with these parameters + + # PROGRAM task with parameters + - name: "program-example" + kind: "PROGRAM" + command: "iconv" + parameters: + - ["-x", "Latin-ASCII", "-o", "file1.txt", "input1.txt"] # First execution + - ["-x", "UTF-8", "-o", "file2.txt", "input2.txt"] # Second execution + + # BUILTIN: Sleep task + - name: "sleep-example" + kind: "BUILTIN" + command: "Sleep" + parameters: + - 5 # Sleep for 5 seconds + + # BUILTIN: Log task + - name: "log-example" + kind: "BUILTIN" + command: "Log" + parameters: + - "WARNING: This is a test message" # Simple string message + - {"Status": "WARNING", "Details": "Test object"} # Object message + + # BUILTIN: SendMail task + - name: "email-example" + kind: "BUILTIN" + command: "SendMail" + parameters: + - username: "user@example.com" + password: "password" + serverhost: "smtp.example.com" + serverport: 587 + senderaddr: "user@example.com" + toaddr: ["recipient@example.com"] + subject: "pg_timetable - Test Email" + msgbody: "

Hello User,

This is a test email.

" + contenttype: "text/html; charset=UTF-8" \ No newline at end of file diff --git a/samples/yaml/Shell.yaml b/samples/yaml/Shell.yaml new file mode 100644 index 00000000..48a149ea --- /dev/null +++ b/samples/yaml/Shell.yaml @@ -0,0 +1,24 @@ +# Shell/Program task example +# Equivalent to samples/Shell.sql + +chains: + - name: "psql chain" + schedule: "* * * * *" # Every minute + live: true + + tasks: + - name: "Run psql command" + kind: "PROGRAM" + command: "psql" + parameters: + - "-h" + - "localhost" # Will be replaced with actual host at runtime + - "-p" + - "5432" # Will be replaced with actual port at runtime + - "-d" + - "postgres" # Will be replaced with actual database at runtime + - "-U" + - "postgres" # Will be replaced with actual user at runtime + - "-w" + - "-c" + - "SELECT now();" \ No newline at end of file From 464bdfe8286ed2e169f9592314642c9a39f7f332 Mon Sep 17 00:00:00 2001 From: Pavlo Golub Date: Mon, 29 Sep 2025 17:45:22 +0200 Subject: [PATCH 2/7] fix tests --- internal/pgengine/pgengine_test.go | 3 +++ internal/pgengine/types.go | 2 +- internal/pgengine/yaml.go | 8 ++++---- 3 files changed, 8 insertions(+), 5 deletions(-) diff --git a/internal/pgengine/pgengine_test.go b/internal/pgengine/pgengine_test.go index 1e4873a9..6e3cf7de 100644 --- a/internal/pgengine/pgengine_test.go +++ b/internal/pgengine/pgengine_test.go @@ -177,6 +177,9 @@ func TestSamplesScripts(t *testing.T) { l := log.Init(config.LoggingOpts{LogLevel: "panic", LogDBLevel: "none"}) pge := container.Engine for _, f := range files { + if f.IsDir() { + continue + } ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() assert.NoError(t, pge.ExecuteCustomScripts(ctx, "../../samples/"+f.Name()), diff --git a/internal/pgengine/types.go b/internal/pgengine/types.go index daceea17..cb1bb68e 100644 --- a/internal/pgengine/types.go +++ b/internal/pgengine/types.go @@ -9,7 +9,7 @@ import ( ) type executor interface { - Exec(ctx context.Context, sql string, 
arguments ...interface{}) (commandTag pgconn.CommandTag, err error) + Exec(ctx context.Context, sql string, arguments ...any) (commandTag pgconn.CommandTag, err error) } // Chain structure used to represent tasks chains diff --git a/internal/pgengine/yaml.go b/internal/pgengine/yaml.go index 57f9aee8..45a36fc5 100644 --- a/internal/pgengine/yaml.go +++ b/internal/pgengine/yaml.go @@ -12,7 +12,7 @@ import ( // YamlChain represents a chain with tasks for YAML processing type YamlChain struct { - Chain + Chain `yaml:",inline"` ClientName string `db:"client_name" yaml:"client_name,omitempty"` Schedule string `db:"run_at" yaml:"schedule,omitempty"` Live bool `db:"live" yaml:"live,omitempty"` @@ -21,9 +21,9 @@ type YamlChain struct { // YamlTask extends the basic task structure with Parameters field type YamlTask struct { - ChainTask - TaskName string `db:"task_name" yaml:"name,omitempty"` - Parameters []interface{} `yaml:"parameters,omitempty"` + ChainTask `yaml:",inline"` + TaskName string `db:"task_name" yaml:"name,omitempty"` + Parameters []any `yaml:"parameters,omitempty"` } // YamlConfig represents the root YAML configuration From 688bd4b7b97d2b9ed114105531bd7a9fa699e7b6 Mon Sep 17 00:00:00 2001 From: Pavlo Golub Date: Mon, 29 Sep 2025 18:59:48 +0200 Subject: [PATCH 3/7] simplify LoadYamlChains --- internal/log/log_test.go | 2 +- internal/pgengine/bootstrap.go | 8 +- internal/pgengine/bootstrap_test.go | 12 +- internal/pgengine/log_hook.go | 4 +- internal/pgengine/transaction.go | 2 +- internal/pgengine/yaml.go | 56 +- internal/pgengine/yaml_test.go | 968 +++++++++++++++++++++++++++- 7 files changed, 995 insertions(+), 57 deletions(-) diff --git a/internal/log/log_test.go b/internal/log/log_test.go index 2c33957c..cb4af90f 100644 --- a/internal/log/log_test.go +++ b/internal/log/log_test.go @@ -35,6 +35,6 @@ func TestPgxLog(*testing.T) { pgxl := log.NewPgxLogger(log.Init(config.LoggingOpts{LogLevel: "trace"})) var level tracelog.LogLevel for level = 
tracelog.LogLevelNone; level <= tracelog.LogLevelTrace; level++ { - pgxl.Log(context.Background(), level, "foo", map[string]interface{}{"func": "TestPgxLog"}) + pgxl.Log(context.Background(), level, "foo", map[string]any{"func": "TestPgxLog"}) } } diff --git a/internal/pgengine/bootstrap.go b/internal/pgengine/bootstrap.go index 8b84ff93..f3141387 100644 --- a/internal/pgengine/bootstrap.go +++ b/internal/pgengine/bootstrap.go @@ -32,9 +32,9 @@ var backoff = retry.WithCappedDuration(maxWaitTime, retry.NewExponential(WaitTim // PgxIface is common interface for every pgx class type PgxIface interface { Begin(ctx context.Context) (pgx.Tx, error) - Exec(context.Context, string, ...interface{}) (pgconn.CommandTag, error) - QueryRow(context.Context, string, ...interface{}) pgx.Row - Query(ctx context.Context, query string, args ...interface{}) (pgx.Rows, error) + Exec(context.Context, string, ...any) (pgconn.CommandTag, error) + QueryRow(context.Context, string, ...any) pgx.Row + Query(ctx context.Context, query string, args ...any) (pgx.Rows, error) Ping(ctx context.Context) error CopyFrom(ctx context.Context, tableName pgx.Identifier, columnNames []string, rowSrc pgx.CopyFromSource) (int64, error) } @@ -194,7 +194,7 @@ func (pge *PgEngine) AddLogHook(ctx context.Context) { // QueryRowIface specifies interface to use QueryRow method type QueryRowIface interface { - QueryRow(context.Context, string, ...interface{}) pgx.Row + QueryRow(context.Context, string, ...any) pgx.Row } // TryLockClientName obtains lock on the server to prevent another client with the same name diff --git a/internal/pgengine/bootstrap_test.go b/internal/pgengine/bootstrap_test.go index 8ee91a6a..3448ff00 100644 --- a/internal/pgengine/bootstrap_test.go +++ b/internal/pgengine/bootstrap_test.go @@ -78,10 +78,10 @@ func TestFinalizeConnection(t *testing.T) { } type mockpgrow struct { - results []interface{} + results []any } -func (r *mockpgrow) Scan(dest ...interface{}) error { +func (r *mockpgrow) 
Scan(dest ...any) error { if len(r.results) > 0 { if err, ok := r.results[0].(error); ok { r.results = r.results[1:] @@ -103,7 +103,7 @@ type mockpgconn struct { r pgx.Row } -func (m mockpgconn) QueryRow(context.Context, string, ...interface{}) pgx.Row { +func (m mockpgconn) QueryRow(context.Context, string, ...any) pgx.Row { return m.r } @@ -117,7 +117,7 @@ func TestTryLockClientName(t *testing.T) { }) t.Run("no schema yet", func(t *testing.T) { - r := &mockpgrow{results: []interface{}{ + r := &mockpgrow{results: []any{ 0, //procoid }} m := mockpgconn{r} @@ -125,7 +125,7 @@ func TestTryLockClientName(t *testing.T) { }) t.Run("locking error", func(t *testing.T) { - r := &mockpgrow{results: []interface{}{ + r := &mockpgrow{results: []any{ 1, //procoid errors.New("locking error"), //error }} @@ -134,7 +134,7 @@ func TestTryLockClientName(t *testing.T) { }) t.Run("locking successful", func(t *testing.T) { - r := &mockpgrow{results: []interface{}{ + r := &mockpgrow{results: []any{ 1, //procoid true, //locked }} diff --git a/internal/pgengine/log_hook.go b/internal/pgengine/log_hook.go index 4de87108..67480419 100644 --- a/internal/pgengine/log_hook.go +++ b/internal/pgengine/log_hook.go @@ -140,12 +140,12 @@ func (hook *LogHook) send(cache []logrus.Entry) { pgx.Identifier{"timetable", "log"}, []string{"ts", "client_name", "pid", "log_level", "message", "message_data"}, pgx.CopyFromSlice(len(cache), - func(i int) ([]interface{}, error) { + func(i int) ([]any, error) { jsonData, err := json.Marshal(cache[i].Data) if err != nil { return nil, err } - return []interface{}{cache[i].Time, + return []any{cache[i].Time, hook.client, hook.pid, adaptEntryLevel(cache[i].Level), diff --git a/internal/pgengine/transaction.go b/internal/pgengine/transaction.go index 14814986..50dac9ee 100644 --- a/internal/pgengine/transaction.go +++ b/internal/pgengine/transaction.go @@ -119,7 +119,7 @@ func (pge *PgEngine) ExecAutonomousSQLTask(ctx context.Context, task *ChainTask, // 
ExecuteSQLCommand executes chain command with parameters inside transaction func (pge *PgEngine) ExecuteSQLCommand(ctx context.Context, executor executor, command string, paramValues []string) (out string, err error) { var ct pgconn.CommandTag - var params []interface{} + var params []any if strings.TrimSpace(command) == "" { return "", errors.New("SQL command cannot be empty") } diff --git a/internal/pgengine/yaml.go b/internal/pgengine/yaml.go index 45a36fc5..9c838959 100644 --- a/internal/pgengine/yaml.go +++ b/internal/pgengine/yaml.go @@ -58,42 +58,12 @@ func (pge *PgEngine) LoadYamlChains(ctx context.Context, filePath string, replac return fmt.Errorf("chain '%s' already exists (use --replace flag to overwrite)", yamlChain.ChainName) } - // Use existing add_job function for single-task chains - if len(yamlChain.Tasks) == 1 { - task := yamlChain.Tasks[0] - params, _ := task.ToSQLParameters() - var paramsValue interface{} - if params == "" { - paramsValue = nil - } else { - paramsValue = params - } - - _, err := pge.ConfigDb.Exec(ctx, ` - SELECT timetable.add_job($1, $2, $3, $4::jsonb, $5::timetable.command_kind, $6, $7, $8, $9, $10, $11, $12)`, - yamlChain.ChainName, - yamlChain.Schedule, - task.Command, - paramsValue, - task.Kind, - nullString(yamlChain.ClientName), - yamlChain.MaxInstances, - yamlChain.Live, - yamlChain.SelfDestruct, - task.IgnoreError, - yamlChain.ExclusiveExecution, - nullString(yamlChain.OnError)) - if err != nil { - return fmt.Errorf("failed to create chain %s: %w", yamlChain.ChainName, err) - } - } else { - // Multi-task chain - use direct SQL - chainID, err := pge.createChainFromYaml(ctx, &yamlChain) - if err != nil { - return fmt.Errorf("failed to create multi-task chain %s: %w", yamlChain.ChainName, err) - } - pge.l.WithField("chain", yamlChain.ChainName).WithField("chain_id", chainID).Info("Created multi-task chain") + // Multi-task chain - use direct SQL + chainID, err := pge.createChainFromYaml(ctx, &yamlChain) + if err != nil { + 
return fmt.Errorf("failed to create multi-task chain %s: %w", yamlChain.ChainName, err) } + pge.l.WithField("chain", yamlChain.ChainName).WithField("chain_id", chainID).Info("Created multi-task chain") } pge.l.WithField("chains", len(yamlConfig.Chains)).WithField("file", filePath).Info("Successfully imported YAML chains") @@ -167,7 +137,7 @@ func (pge *PgEngine) createChainFromYaml(ctx context.Context, yamlChain *YamlCha } // nullString returns nil for empty strings, otherwise returns the string -func nullString(s string) interface{} { +func nullString(s string) any { if s == "" { return nil } @@ -185,10 +155,16 @@ func (c *YamlChain) ValidateChain() error { } // Validate cron format - switch c.Schedule { - case "", "@reboot", "@after", "@every": - // Valid special schedules - default: + specialSchedules := []string{"@reboot", "@after", "@every"} + isSpecial := false + for _, s := range specialSchedules { + if strings.HasPrefix(c.Schedule, s) { + isSpecial = true + break + } + } + + if !isSpecial { fields := strings.Fields(c.Schedule) if len(fields) != 5 { return fmt.Errorf("invalid cron format: %s (expected 5 fields)", c.Schedule) diff --git a/internal/pgengine/yaml_test.go b/internal/pgengine/yaml_test.go index 76f9c443..ad566376 100644 --- a/internal/pgengine/yaml_test.go +++ b/internal/pgengine/yaml_test.go @@ -5,10 +5,11 @@ import ( "os" "testing" - "github.com/cybertec-postgresql/pg_timetable/internal/pgengine" - "github.com/cybertec-postgresql/pg_timetable/internal/testutils" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + + "github.com/cybertec-postgresql/pg_timetable/internal/pgengine" + "github.com/cybertec-postgresql/pg_timetable/internal/testutils" ) // Helper function to create temporary YAML file @@ -165,7 +166,7 @@ func TestYamlParameterHandling(t *testing.T) { sqlTask := chain.Tasks[0] require.Equal(t, "SQL", sqlTask.Kind) require.Equal(t, 2, len(sqlTask.Parameters)) - sqlParam1, ok := 
sqlTask.Parameters[0].([]interface{}) + sqlParam1, ok := sqlTask.Parameters[0].([]any) require.True(t, ok, "SQL parameter should be an array") assert.Equal(t, 3, len(sqlParam1)) @@ -192,3 +193,964 @@ func TestYamlParameterHandling(t *testing.T) { require.True(t, ok, "Log parameter can be a string") assert.Equal(t, "warning message", logParam1) } + +func TestParseYamlFile(t *testing.T) { + t.Run("Valid YAML file", func(t *testing.T) { + yamlContent := `chains: + - name: "test-chain" + schedule: "0 * * * *" + live: true + max_instances: 2 + timeout: 300 + self_destruct: true + exclusive: true + client_name: "test-client" + on_error: "RETRY" + tasks: + - name: "test-task" + kind: "SQL" + command: "SELECT 1" + parameters: ["param1", 42, true] + ignore_error: false + autonomous: false + timeout: 60 + run_as: "postgres" + connect_string: "test-db"` + + tmpfile := createTempYamlFile(t, yamlContent) + defer removeTempFile(t, tmpfile) + + config, err := pgengine.ParseYamlFile(tmpfile) + require.NoError(t, err) + require.Len(t, config.Chains, 1) + + chain := config.Chains[0] + assert.Equal(t, "test-chain", chain.ChainName) + assert.Equal(t, "0 * * * *", chain.Schedule) + assert.True(t, chain.Live) + assert.Equal(t, 2, chain.MaxInstances) + assert.Equal(t, 300, chain.Timeout) + assert.True(t, chain.SelfDestruct) + assert.True(t, chain.ExclusiveExecution) + assert.Equal(t, "test-client", chain.ClientName) + assert.Equal(t, "RETRY", chain.OnError) + + require.Len(t, chain.Tasks, 1) + task := chain.Tasks[0] + assert.Equal(t, "test-task", task.TaskName) + assert.Equal(t, "SQL", task.Kind) + assert.Equal(t, "SELECT 1", task.Command) + assert.False(t, task.IgnoreError) + assert.False(t, task.Autonomous) + assert.Equal(t, 60, task.Timeout) + assert.Equal(t, "postgres", task.RunAs) + assert.Equal(t, "test-db", task.ConnectString) + }) + + t.Run("File not found", func(t *testing.T) { + _, err := pgengine.ParseYamlFile("/non/existent/file.yaml") + assert.Error(t, err) + 
assert.Contains(t, err.Error(), "file not found") + }) + + t.Run("Invalid file extension", func(t *testing.T) { + tmpfile, err := os.CreateTemp("", "test-*.txt") + require.NoError(t, err) + tmpfileName := tmpfile.Name() + defer os.Remove(tmpfileName) + tmpfile.Close() + + _, err = pgengine.ParseYamlFile(tmpfileName) + assert.Error(t, err) + assert.Contains(t, err.Error(), "file must have .yaml or .yml extension") + }) + + t.Run("Invalid YAML syntax", func(t *testing.T) { + invalidYaml := `chains: + - name: "test" + schedule: "* * * * *" + tasks: + - name: "task1" + kind: "SQL + command: SELECT 1` + + tmpfile := createTempYamlFile(t, invalidYaml) + defer removeTempFile(t, tmpfile) + + _, err := pgengine.ParseYamlFile(tmpfile) + assert.Error(t, err) + assert.Contains(t, err.Error(), "failed to parse YAML") + }) + + t.Run("Validation errors", func(t *testing.T) { + invalidChain := `chains: + - name: "" + schedule: "* * * * *" + tasks: + - command: "SELECT 1"` + + tmpfile := createTempYamlFile(t, invalidChain) + defer removeTempFile(t, tmpfile) + + _, err := pgengine.ParseYamlFile(tmpfile) + assert.Error(t, err) + assert.Contains(t, err.Error(), "chain name is required") + }) +} + +func TestYamlChainValidation(t *testing.T) { + t.Run("Valid chain", func(t *testing.T) { + chain := &pgengine.YamlChain{ + Chain: pgengine.Chain{ + ChainName: "test-chain", + }, + Schedule: "0 * * * *", + Tasks: []pgengine.YamlTask{ + { + ChainTask: pgengine.ChainTask{ + Command: "SELECT 1", + Kind: "SQL", + }, + }, + }, + } + + err := chain.ValidateChain() + assert.NoError(t, err) + }) + + t.Run("Missing chain name", func(t *testing.T) { + chain := &pgengine.YamlChain{ + Schedule: "0 * * * *", + Tasks: []pgengine.YamlTask{ + { + ChainTask: pgengine.ChainTask{ + Command: "SELECT 1", + Kind: "SQL", + }, + }, + }, + } + + err := chain.ValidateChain() + assert.Error(t, err) + assert.Contains(t, err.Error(), "chain name is required") + }) + + t.Run("Missing schedule", func(t *testing.T) { + 
chain := &pgengine.YamlChain{ + Chain: pgengine.Chain{ + ChainName: "test-chain", + }, + Tasks: []pgengine.YamlTask{ + { + ChainTask: pgengine.ChainTask{ + Command: "SELECT 1", + Kind: "SQL", + }, + }, + }, + } + + err := chain.ValidateChain() + assert.Error(t, err) + assert.Contains(t, err.Error(), "chain schedule is required") + }) + + t.Run("Invalid cron format", func(t *testing.T) { + chain := &pgengine.YamlChain{ + Chain: pgengine.Chain{ + ChainName: "test-chain", + }, + Schedule: "invalid cron", + Tasks: []pgengine.YamlTask{ + { + ChainTask: pgengine.ChainTask{ + Command: "SELECT 1", + Kind: "SQL", + }, + }, + }, + } + + err := chain.ValidateChain() + assert.Error(t, err) + assert.Contains(t, err.Error(), "invalid cron format") + }) + + t.Run("Special schedules", func(t *testing.T) { + specialSchedules := []string{"@reboot", "@after", "@every"} + for _, schedule := range specialSchedules { + chain := &pgengine.YamlChain{ + Chain: pgengine.Chain{ + ChainName: "test-chain", + }, + Schedule: schedule, + Tasks: []pgengine.YamlTask{ + { + ChainTask: pgengine.ChainTask{ + Command: "SELECT 1", + Kind: "SQL", + }, + }, + }, + } + + err := chain.ValidateChain() + assert.NoError(t, err, "Schedule %s should be valid", schedule) + } + }) + + t.Run("No tasks", func(t *testing.T) { + chain := &pgengine.YamlChain{ + Chain: pgengine.Chain{ + ChainName: "test-chain", + }, + Schedule: "0 * * * *", + Tasks: []pgengine.YamlTask{}, + } + + err := chain.ValidateChain() + assert.Error(t, err) + assert.Contains(t, err.Error(), "chain must have at least one task") + }) + + t.Run("Task validation error", func(t *testing.T) { + chain := &pgengine.YamlChain{ + Chain: pgengine.Chain{ + ChainName: "test-chain", + }, + Schedule: "0 * * * *", + Tasks: []pgengine.YamlTask{ + { + ChainTask: pgengine.ChainTask{ + Command: "", // Invalid empty command + Kind: "SQL", + }, + }, + }, + } + + err := chain.ValidateChain() + assert.Error(t, err) + assert.Contains(t, err.Error(), "task 1:") + 
assert.Contains(t, err.Error(), "task command is required") + }) +} + +func TestYamlTaskValidation(t *testing.T) { + t.Run("Valid task", func(t *testing.T) { + task := &pgengine.YamlTask{ + ChainTask: pgengine.ChainTask{ + Command: "SELECT 1", + Kind: "SQL", + Timeout: 60, + }, + } + + err := task.ValidateTask() + assert.NoError(t, err) + }) + + t.Run("Missing command", func(t *testing.T) { + task := &pgengine.YamlTask{ + ChainTask: pgengine.ChainTask{ + Kind: "SQL", + }, + } + + err := task.ValidateTask() + assert.Error(t, err) + assert.Contains(t, err.Error(), "task command is required") + }) + + t.Run("Valid kinds", func(t *testing.T) { + validKinds := []string{"", "SQL", "PROGRAM", "BUILTIN", "sql", "program", "builtin"} + for _, kind := range validKinds { + task := &pgengine.YamlTask{ + ChainTask: pgengine.ChainTask{ + Command: "SELECT 1", + Kind: kind, + }, + } + + err := task.ValidateTask() + assert.NoError(t, err, "Kind %s should be valid", kind) + } + }) + + t.Run("Invalid kind", func(t *testing.T) { + task := &pgengine.YamlTask{ + ChainTask: pgengine.ChainTask{ + Command: "SELECT 1", + Kind: "INVALID", + }, + } + + err := task.ValidateTask() + assert.Error(t, err) + assert.Contains(t, err.Error(), "invalid task kind: INVALID") + }) + + t.Run("Negative timeout", func(t *testing.T) { + task := &pgengine.YamlTask{ + ChainTask: pgengine.ChainTask{ + Command: "SELECT 1", + Kind: "SQL", + Timeout: -1, + }, + } + + err := task.ValidateTask() + assert.Error(t, err) + assert.Contains(t, err.Error(), "task timeout must be non-negative") + }) +} + +func TestYamlChainSetDefaults(t *testing.T) { + t.Run("Set default schedule", func(t *testing.T) { + chain := &pgengine.YamlChain{ + Chain: pgengine.Chain{ + ChainName: "test-chain", + }, + Tasks: []pgengine.YamlTask{ + { + ChainTask: pgengine.ChainTask{ + Command: "SELECT 1", + }, + }, + }, + } + + chain.SetDefaults() + assert.Equal(t, "* * * * *", chain.Schedule) + assert.Equal(t, "SQL", chain.Tasks[0].Kind) + }) + + 
t.Run("Keep existing values", func(t *testing.T) { + chain := &pgengine.YamlChain{ + Chain: pgengine.Chain{ + ChainName: "test-chain", + }, + Schedule: "0 0 * * *", + Tasks: []pgengine.YamlTask{ + { + ChainTask: pgengine.ChainTask{ + Command: "echo hello", + Kind: "PROGRAM", + }, + }, + }, + } + + chain.SetDefaults() + assert.Equal(t, "0 0 * * *", chain.Schedule) + assert.Equal(t, "PROGRAM", chain.Tasks[0].Kind) + }) +} + +func TestToSQLParameters(t *testing.T) { + t.Run("No parameters", func(t *testing.T) { + task := &pgengine.YamlTask{} + result, err := task.ToSQLParameters() + assert.NoError(t, err) + assert.Equal(t, "", result) + }) + + t.Run("String parameters", func(t *testing.T) { + task := &pgengine.YamlTask{ + Parameters: []any{"hello", "world"}, + } + result, err := task.ToSQLParameters() + assert.NoError(t, err) + assert.Equal(t, `["hello", "world"]`, result) + }) + + t.Run("String with quotes", func(t *testing.T) { + task := &pgengine.YamlTask{ + Parameters: []any{`hello "quoted" world`}, + } + result, err := task.ToSQLParameters() + assert.NoError(t, err) + assert.Equal(t, `["hello \"quoted\" world"]`, result) + }) + + t.Run("Integer parameters", func(t *testing.T) { + task := &pgengine.YamlTask{ + Parameters: []any{42, int32(100), int64(200)}, + } + result, err := task.ToSQLParameters() + assert.NoError(t, err) + assert.Equal(t, `[42, 100, 200]`, result) + }) + + t.Run("Float parameters", func(t *testing.T) { + task := &pgengine.YamlTask{ + Parameters: []any{3.14, float32(2.71)}, + } + result, err := task.ToSQLParameters() + assert.NoError(t, err) + assert.Equal(t, `[3.14, 2.71]`, result) + }) + + t.Run("Boolean parameters", func(t *testing.T) { + task := &pgengine.YamlTask{ + Parameters: []any{true, false}, + } + result, err := task.ToSQLParameters() + assert.NoError(t, err) + assert.Equal(t, `[true, false]`, result) + }) + + t.Run("Mixed parameter types", func(t *testing.T) { + task := &pgengine.YamlTask{ + Parameters: []any{"text", 42, 3.14, true, 
nil}, + } + result, err := task.ToSQLParameters() + assert.NoError(t, err) + assert.Equal(t, `["text", 42, 3.14, true, ""]`, result) + }) +} + +func TestNullString(t *testing.T) { + // Note: nullString is not exported, so we test it indirectly through chain creation + t.Run("Indirect test via chain creation", func(t *testing.T) { + container, cleanup := testutils.SetupPostgresContainer(t) + defer cleanup() + + ctx := context.Background() + pge := container.Engine + + yamlContent := `chains: + - name: "test-null-strings" + schedule: "0 0 * * *" + client_name: "" # Should become NULL in database + on_error: "" # Should become NULL in database + tasks: + - name: "test-task" + command: "SELECT 1" + kind: "SQL" + run_as: "" # Should become NULL + database_connection: "" # Should become NULL` + + tmpfile := createTempYamlFile(t, yamlContent) + defer removeTempFile(t, tmpfile) + + err := pge.LoadYamlChains(ctx, tmpfile, false) + require.NoError(t, err) + + // Verify NULL values in database + var clientName, onError any + err = pge.ConfigDb.QueryRow(ctx, + "SELECT client_name, on_error FROM timetable.chain WHERE chain_name = $1", + "test-null-strings").Scan(&clientName, &onError) + require.NoError(t, err) + assert.Nil(t, clientName) + assert.Nil(t, onError) + }) +} + +func TestLoadYamlChainsMultiTask(t *testing.T) { + container, cleanup := testutils.SetupPostgresContainer(t) + defer cleanup() + + ctx := context.Background() + pge := container.Engine + + t.Run("Multi-task chain creation", func(t *testing.T) { + yamlContent := `chains: + - name: "multi-task-chain" + schedule: "0 0 * * *" + live: true + max_instances: 2 + timeout: 300 + self_destruct: false + exclusive: true + client_name: "test-client" + on_error: "CONTINUE" + tasks: + - name: "first-task" + kind: "SQL" + command: "SELECT 1" + ignore_error: false + autonomous: false + timeout: 60 + run_as: "postgres" + database_connection: "main" + parameters: ["param1", 42] + - name: "second-task" + kind: "PROGRAM" + 
command: "echo" + ignore_error: true + autonomous: true + timeout: 30 + parameters: ["hello", "world"]` + + tmpfile := createTempYamlFile(t, yamlContent) + defer removeTempFile(t, tmpfile) + + err := pge.LoadYamlChains(ctx, tmpfile, false) + require.NoError(t, err) + + // Verify chain was created + var chainID int64 + err = pge.ConfigDb.QueryRow(ctx, + "SELECT chain_id FROM timetable.chain WHERE chain_name = $1", + "multi-task-chain").Scan(&chainID) + require.NoError(t, err) + assert.Greater(t, chainID, int64(0)) + + // Verify tasks were created + var taskCount int + err = pge.ConfigDb.QueryRow(ctx, + "SELECT COUNT(*) FROM timetable.task WHERE chain_id = $1", + chainID).Scan(&taskCount) + require.NoError(t, err) + assert.Equal(t, 2, taskCount) + + // Verify task parameters were created + var paramCount int + err = pge.ConfigDb.QueryRow(ctx, + "SELECT COUNT(*) FROM timetable.parameter p JOIN timetable.task t ON p.task_id = t.task_id WHERE t.chain_id = $1", + chainID).Scan(&paramCount) + require.NoError(t, err) + assert.Equal(t, 2, paramCount) // 2 tasks with parameters + }) + + t.Run("Chain already exists without replace", func(t *testing.T) { + yamlContent := `chains: + - name: "existing-chain" + schedule: "0 0 * * *" + tasks: + - command: "SELECT 1" + kind: "SQL"` + + tmpfile := createTempYamlFile(t, yamlContent) + defer removeTempFile(t, tmpfile) + + // First load should succeed + err := pge.LoadYamlChains(ctx, tmpfile, false) + require.NoError(t, err) + + // Second load without replace should fail + err = pge.LoadYamlChains(ctx, tmpfile, false) + assert.Error(t, err) + assert.Contains(t, err.Error(), "already exists") + }) + + t.Run("Database error during chain creation", func(t *testing.T) { + // Test with invalid schedule that fails YAML validation + yamlContent := `chains: + - name: "invalid-schedule-chain" + schedule: "invalid cron expression that passes validation but fails in DB" + tasks: + - command: "SELECT 1" + kind: "SQL"` + + tmpfile := 
createTempYamlFile(t, yamlContent) + defer removeTempFile(t, tmpfile) + + // This should fail at YAML validation level due to invalid cron format + err := pge.LoadYamlChains(ctx, tmpfile, false) + assert.Error(t, err) + assert.Contains(t, err.Error(), "failed to parse YAML file") + }) +} + +func TestLoadYamlChainsErrorCases(t *testing.T) { + container, cleanup := testutils.SetupPostgresContainer(t) + defer cleanup() + + ctx := context.Background() + pge := container.Engine + + t.Run("Invalid YAML file", func(t *testing.T) { + err := pge.LoadYamlChains(ctx, "/non/existent/file.yaml", false) + assert.Error(t, err) + assert.Contains(t, err.Error(), "failed to parse YAML file") + }) + + t.Run("Invalid YAML content", func(t *testing.T) { + invalidYaml := `chains: + - name: "" # Invalid: empty name + schedule: "0 0 * * *" + tasks: + - command: "SELECT 1"` + + tmpfile := createTempYamlFile(t, invalidYaml) + defer removeTempFile(t, tmpfile) + + err := pge.LoadYamlChains(ctx, tmpfile, false) + assert.Error(t, err) + assert.Contains(t, err.Error(), "failed to parse YAML file") + }) +} + +func TestCreateChainFromYamlEdgeCases(t *testing.T) { + container, cleanup := testutils.SetupPostgresContainer(t) + defer cleanup() + + ctx := context.Background() + pge := container.Engine + + t.Run("Task with no parameters", func(t *testing.T) { + yamlContent := `chains: + - name: "no-params-chain" + schedule: "0 0 * * *" + tasks: + - name: "no-param-task" + kind: "SQL" + command: "SELECT CURRENT_TIMESTAMP" + - name: "empty-param-task" + kind: "SQL" + command: "SELECT 1" + parameters: []` + + tmpfile := createTempYamlFile(t, yamlContent) + defer removeTempFile(t, tmpfile) + + err := pge.LoadYamlChains(ctx, tmpfile, false) + require.NoError(t, err) + + // Verify no parameters were inserted + var paramCount int + err = pge.ConfigDb.QueryRow(ctx, + `SELECT COUNT(*) FROM timetable.parameter p + JOIN timetable.task t ON p.task_id = t.task_id + JOIN timetable.chain c ON t.chain_id = c.chain_id + 
WHERE c.chain_name = $1`, + "no-params-chain").Scan(&paramCount) + require.NoError(t, err) + assert.Equal(t, 0, paramCount) + }) + + t.Run("Complex parameter types", func(t *testing.T) { + yamlContent := `chains: + - name: "complex-params-chain" + schedule: "0 0 * * *" + tasks: + - name: "complex-task" + kind: "SQL" + command: "SELECT $1::jsonb" + parameters: + - {"key": "value", "number": 42, "nested": {"inner": true}}` + + tmpfile := createTempYamlFile(t, yamlContent) + defer removeTempFile(t, tmpfile) + + err := pge.LoadYamlChains(ctx, tmpfile, false) + require.NoError(t, err) + + // Verify parameter was stored + var paramValue string + err = pge.ConfigDb.QueryRow(ctx, + `SELECT p.value FROM timetable.parameter p + JOIN timetable.task t ON p.task_id = t.task_id + JOIN timetable.chain c ON t.chain_id = c.chain_id + WHERE c.chain_name = $1`, + "complex-params-chain").Scan(&paramValue) + require.NoError(t, err) + assert.Contains(t, paramValue, "key") + assert.Contains(t, paramValue, "value") + }) +} + +func TestToSQLParametersErrorHandling(t *testing.T) { + t.Run("Error in parameter conversion", func(t *testing.T) { + task := &pgengine.YamlTask{ + Parameters: []any{ + make(chan int), // unsupported type that can't be converted + }, + } + result, err := task.ToSQLParameters() + assert.NoError(t, err) // Function doesn't actually return errors, just converts to string + assert.Contains(t, result, "0x") // channel representation + }) + + t.Run("Empty array parameters", func(t *testing.T) { + task := &pgengine.YamlTask{ + Parameters: []any{[]any{}}, + } + result, err := task.ToSQLParameters() + assert.NoError(t, err) + assert.Equal(t, `["[]"]`, result) + }) + + t.Run("Complex nested structures", func(t *testing.T) { + task := &pgengine.YamlTask{ + Parameters: []any{ + map[string]any{ + "nested": map[string]any{ + "deep": []any{1, 2, 3}, + }, + }, + }, + } + result, err := task.ToSQLParameters() + assert.NoError(t, err) + assert.Contains(t, result, "nested") + }) +} + +func 
TestCreateChainFromYamlErrorHandling(t *testing.T) { + container, cleanup := testutils.SetupPostgresContainer(t) + defer cleanup() + + ctx := context.Background() + pge := container.Engine + + t.Run("Multi-task chain with invalid parameter conversion", func(t *testing.T) { + yamlContent := `chains: + - name: "param-error-chain" + schedule: "0 0 * * *" + tasks: + - name: "first-task" + kind: "SQL" + command: "SELECT 1" + parameters: [{"invalid": {"deeply": {"nested": "value"}}}] + - name: "second-task" + kind: "SQL" + command: "SELECT 2"` + + tmpfile := createTempYamlFile(t, yamlContent) + defer removeTempFile(t, tmpfile) + + err := pge.LoadYamlChains(ctx, tmpfile, false) + // Should succeed even with complex parameters + require.NoError(t, err) + + // Verify chain and tasks were created + var chainID int64 + err = pge.ConfigDb.QueryRow(ctx, + "SELECT chain_id FROM timetable.chain WHERE chain_name = $1", + "param-error-chain").Scan(&chainID) + require.NoError(t, err) + + var taskCount int + err = pge.ConfigDb.QueryRow(ctx, + "SELECT COUNT(*) FROM timetable.task WHERE chain_id = $1", + chainID).Scan(&taskCount) + require.NoError(t, err) + assert.Equal(t, 2, taskCount) + }) + + t.Run("Multi-task chain with various field types", func(t *testing.T) { + yamlContent := `chains: + - name: "comprehensive-multi-task" + schedule: "@every 1h" + live: false + max_instances: 3 + timeout: 600 + self_destruct: true + exclusive: false + client_name: "test-client-multi" + on_error: "IGNORE" + tasks: + - name: "sql-task" + kind: "SQL" + command: "SELECT $1, $2" + parameters: ["string", 123] + ignore_error: true + autonomous: false + timeout: 120 + run_as: "test_user" + connect_string: "dbname=test" + - name: "program-task" + kind: "PROGRAM" + command: "echo" + parameters: ["hello", "world"] + ignore_error: false + autonomous: true + timeout: 60 + - name: "builtin-task" + kind: "BUILTIN" + command: "Sleep" + parameters: [5] + ignore_error: false + autonomous: false + timeout: 10` + + 
tmpfile := createTempYamlFile(t, yamlContent) + defer removeTempFile(t, tmpfile) + + err := pge.LoadYamlChains(ctx, tmpfile, false) + require.NoError(t, err) + + // Verify all chain properties + var chainID int64 + var schedule, clientName, onError string + var live, selfDestruct, exclusive bool + var maxInstances, timeout int + err = pge.ConfigDb.QueryRow(ctx, + `SELECT chain_id, run_at, client_name, on_error, live, + self_destruct, exclusive_execution, max_instances, timeout + FROM timetable.chain WHERE chain_name = $1`, + "comprehensive-multi-task").Scan( + &chainID, &schedule, &clientName, &onError, &live, + &selfDestruct, &exclusive, &maxInstances, &timeout) + require.NoError(t, err) + + assert.Equal(t, "@every 1h", schedule) + assert.Equal(t, "test-client-multi", clientName) + assert.Equal(t, "IGNORE", onError) + assert.False(t, live) + assert.True(t, selfDestruct) + assert.False(t, exclusive) + assert.Equal(t, 3, maxInstances) + assert.Equal(t, 600, timeout) + + // Verify all tasks were created with correct properties + rows, err := pge.ConfigDb.Query(ctx, + `SELECT task_name, kind, command, ignore_error, autonomous, + timeout, run_as, database_connection + FROM timetable.task WHERE chain_id = $1 ORDER BY task_order`, + chainID) + require.NoError(t, err) + defer rows.Close() + + expectedTasks := []struct { + name string + kind string + command string + ignoreErr bool + auto bool + timeout int + runAs *string + dbConn *string + }{ + {"sql-task", "SQL", "SELECT $1, $2", true, false, 120, stringPtr("test_user"), stringPtr("dbname=test")}, + {"program-task", "PROGRAM", "echo", false, true, 60, nil, nil}, + {"builtin-task", "BUILTIN", "Sleep", false, false, 10, nil, nil}, + } + + taskIdx := 0 + for rows.Next() { + var name, kind, command string + var ignoreErr, auto bool + var timeout int + var runAs, dbConn *string + + err := rows.Scan(&name, &kind, &command, &ignoreErr, &auto, &timeout, &runAs, &dbConn) + require.NoError(t, err) + + expected := 
expectedTasks[taskIdx] + assert.Equal(t, expected.name, name) + assert.Equal(t, expected.kind, kind) + assert.Equal(t, expected.command, command) + assert.Equal(t, expected.ignoreErr, ignoreErr) + assert.Equal(t, expected.auto, auto) + assert.Equal(t, expected.timeout, timeout) + assert.Equal(t, expected.runAs, runAs) + assert.Equal(t, expected.dbConn, dbConn) + + taskIdx++ + } + assert.Equal(t, 3, taskIdx) + + // Verify parameters were stored correctly + var paramCount int + err = pge.ConfigDb.QueryRow(ctx, + `SELECT COUNT(*) FROM timetable.parameter p + JOIN timetable.task t ON p.task_id = t.task_id + WHERE t.chain_id = $1`, + chainID).Scan(&paramCount) + require.NoError(t, err) + assert.Equal(t, 3, paramCount) // 3 tasks with parameters + }) +} + +// Helper function to create string pointer +func stringPtr(s string) *string { + return &s +} + +func TestNullStringFunction(t *testing.T) { + // Testing nullString indirectly through database operations + container, cleanup := testutils.SetupPostgresContainer(t) + defer cleanup() + + ctx := context.Background() + pge := container.Engine + + t.Run("All null fields", func(t *testing.T) { + yamlContent := `chains: + - name: "all-nulls-chain" + schedule: "0 0 * * *" + # client_name: "" # Should be NULL + # on_error: "" # Should be NULL + tasks: + - name: "null-task" + command: "SELECT 1" + kind: "SQL" + # run_as: "" # Should be NULL + # connect_string: "" # Should be NULL` + + tmpfile := createTempYamlFile(t, yamlContent) + defer removeTempFile(t, tmpfile) + + err := pge.LoadYamlChains(ctx, tmpfile, false) + require.NoError(t, err) + + // Verify NULL values in chain table + var clientName, onError any + err = pge.ConfigDb.QueryRow(ctx, + "SELECT client_name, on_error FROM timetable.chain WHERE chain_name = $1", + "all-nulls-chain").Scan(&clientName, &onError) + require.NoError(t, err) + assert.Nil(t, clientName) + assert.Nil(t, onError) + + // Verify NULL values in task table + var runAs, dbConn any + err = 
pge.ConfigDb.QueryRow(ctx, + `SELECT run_as, database_connection FROM timetable.task t + JOIN timetable.chain c ON t.chain_id = c.chain_id + WHERE c.chain_name = $1`, + "all-nulls-chain").Scan(&runAs, &dbConn) + require.NoError(t, err) + assert.Nil(t, runAs) + assert.Nil(t, dbConn) + }) + + t.Run("Mixed null and non-null fields", func(t *testing.T) { + yamlContent := `chains: + - name: "mixed-nulls-chain" + schedule: "0 0 * * *" + client_name: "real-client" + # on_error not specified - should be NULL + tasks: + - name: "mixed-task" + command: "SELECT 1" + kind: "SQL" + run_as: "real-user" + # connect_string not specified - should be NULL` + + tmpfile := createTempYamlFile(t, yamlContent) + defer removeTempFile(t, tmpfile) + + err := pge.LoadYamlChains(ctx, tmpfile, false) + require.NoError(t, err) + + // Verify mixed values in chain table + var clientName, onError any + err = pge.ConfigDb.QueryRow(ctx, + "SELECT client_name, on_error FROM timetable.chain WHERE chain_name = $1", + "mixed-nulls-chain").Scan(&clientName, &onError) + require.NoError(t, err) + require.NotNil(t, clientName) + assert.Equal(t, "real-client", clientName) + assert.Nil(t, onError) + + // Verify mixed values in task table + var runAs, dbConn any + err = pge.ConfigDb.QueryRow(ctx, + `SELECT run_as, database_connection FROM timetable.task t + JOIN timetable.chain c ON t.chain_id = c.chain_id + WHERE c.chain_name = $1`, + "mixed-nulls-chain").Scan(&runAs, &dbConn) + require.NoError(t, err) + require.NotNil(t, runAs) + assert.Equal(t, "real-user", runAs) + assert.Nil(t, dbConn) + }) +} From aa2cf0f0ab6d34848b5da42ac8547a8d13de1f6b Mon Sep 17 00:00:00 2001 From: Pavlo Golub Date: Mon, 29 Sep 2025 20:03:44 +0200 Subject: [PATCH 4/7] add TestExecuteFileScript --- internal/pgengine/bootstrap_test.go | 301 ++++++++++++++++++++++++++++ internal/pgengine/yaml.go | 3 +- 2 files changed, 302 insertions(+), 2 deletions(-) diff --git a/internal/pgengine/bootstrap_test.go b/internal/pgengine/bootstrap_test.go 
index 3448ff00..3cae9f46 100644 --- a/internal/pgengine/bootstrap_test.go +++ b/internal/pgengine/bootstrap_test.go @@ -4,10 +4,13 @@ import ( "context" "errors" "fmt" + "os" + "path/filepath" "reflect" "testing" "time" + "github.com/cybertec-postgresql/pg_timetable/internal/config" "github.com/cybertec-postgresql/pg_timetable/internal/pgengine" pgx "github.com/jackc/pgx/v5" "github.com/pashagolub/pgxmock/v4" @@ -142,3 +145,301 @@ func TestTryLockClientName(t *testing.T) { assert.NoError(t, pge.TryLockClientName(context.Background(), m)) }) } + +func TestExecuteFileScript(t *testing.T) { + initmockdb(t) + defer mockPool.Close() + mockpge := pgengine.NewDB(mockPool, "pgengine_unit_test") + + // Create temporary directory for test files + tmpDir := t.TempDir() + + t.Run("SQL file execution", func(t *testing.T) { + // Create temporary SQL file + sqlFile := filepath.Join(tmpDir, "test.sql") + err := os.WriteFile(sqlFile, []byte("SELECT 1;"), 0644) + assert.NoError(t, err) + + // Mock the SQL execution + mockPool.ExpectExec("SELECT 1;").WillReturnResult(pgxmock.NewResult("SELECT", 1)) + + cmdOpts := config.CmdOptions{} + cmdOpts.Start.File = sqlFile + + err = mockpge.ExecuteFileScript(context.Background(), cmdOpts) + assert.NoError(t, err) + }) + + t.Run("SQL file execution error", func(t *testing.T) { + // Create temporary SQL file + sqlFile := filepath.Join(tmpDir, "test_error.sql") + err := os.WriteFile(sqlFile, []byte("SELECT 1;"), 0644) + assert.NoError(t, err) + + // Mock the SQL execution with error + mockPool.ExpectExec("SELECT 1;").WillReturnError(errors.New("SQL execution failed")) + + cmdOpts := config.CmdOptions{} + cmdOpts.Start.File = sqlFile + + err = mockpge.ExecuteFileScript(context.Background(), cmdOpts) + assert.Error(t, err) + }) + + t.Run("YAML file validation mode - valid file", func(t *testing.T) { + // Create temporary YAML file + yamlFile := filepath.Join(tmpDir, "test.yaml") + yamlContent := `chains: + - name: test_chain + tasks: + - name: 
test_task + command: SELECT 1` + err := os.WriteFile(yamlFile, []byte(yamlContent), 0644) + assert.NoError(t, err) + + cmdOpts := config.CmdOptions{} + cmdOpts.Start.File = yamlFile + cmdOpts.Start.Validate = true + + err = mockpge.ExecuteFileScript(context.Background(), cmdOpts) + assert.NoError(t, err) + }) + + t.Run("YAML file validation mode - invalid file", func(t *testing.T) { + // Create temporary YAML file with invalid content + yamlFile := filepath.Join(tmpDir, "invalid.yaml") + yamlContent := `chains: + - name: test_chain + invalid_field: value + - malformed yaml` + err := os.WriteFile(yamlFile, []byte(yamlContent), 0644) + assert.NoError(t, err) + + cmdOpts := config.CmdOptions{} + cmdOpts.Start.File = yamlFile + cmdOpts.Start.Validate = true + + err = mockpge.ExecuteFileScript(context.Background(), cmdOpts) + // Expect error due to invalid YAML structure + assert.Error(t, err) + }) + + t.Run("YAML file import mode", func(t *testing.T) { + // Create temporary YAML file + yamlFile := filepath.Join(tmpDir, "test_import.yaml") + yamlContent := `chains: + - name: test_chain + tasks: + - name: test_task + command: SELECT 1` + err := os.WriteFile(yamlFile, []byte(yamlContent), 0644) + assert.NoError(t, err) + + cmdOpts := config.CmdOptions{} + cmdOpts.Start.File = yamlFile + cmdOpts.Start.Validate = false + cmdOpts.Start.Replace = false + + err = mockpge.ExecuteFileScript(context.Background(), cmdOpts) + assert.NoError(t, err) + }) + + t.Run("YML file extension", func(t *testing.T) { + // Create temporary YML file + ymlFile := filepath.Join(tmpDir, "test.yml") + yamlContent := `chains: + - name: test_chain + tasks: + - name: test_task + command: SELECT 1` + err := os.WriteFile(ymlFile, []byte(yamlContent), 0644) + assert.NoError(t, err) + + cmdOpts := config.CmdOptions{} + cmdOpts.Start.File = ymlFile + cmdOpts.Start.Validate = true + + err = mockpge.ExecuteFileScript(context.Background(), cmdOpts) + assert.NoError(t, err) + }) + + t.Run("File without 
extension - YAML content", func(t *testing.T) { + // Create file without extension with YAML content + noExtFile := filepath.Join(tmpDir, "test_no_ext") + yamlContent := `chains: + - name: test_chain + tasks: + - name: test_task + command: SELECT 1` + err := os.WriteFile(noExtFile, []byte(yamlContent), 0644) + assert.NoError(t, err) + + cmdOpts := config.CmdOptions{} + cmdOpts.Start.File = noExtFile + + err = mockpge.ExecuteFileScript(context.Background(), cmdOpts) + assert.NoError(t, err) + }) + + t.Run("File without extension - SQL content", func(t *testing.T) { + // Create file without extension with SQL content + noExtFile := filepath.Join(tmpDir, "test_no_ext_sql") + sqlContent := "SELECT 1;" + err := os.WriteFile(noExtFile, []byte(sqlContent), 0644) + assert.NoError(t, err) + + // Mock the SQL execution + mockPool.ExpectExec("SELECT 1;").WillReturnResult(pgxmock.NewResult("SELECT", 1)) + + cmdOpts := config.CmdOptions{} + cmdOpts.Start.File = noExtFile + + err = mockpge.ExecuteFileScript(context.Background(), cmdOpts) + assert.NoError(t, err) + }) + + t.Run("File not found error - SQL file", func(t *testing.T) { + cmdOpts := config.CmdOptions{} + cmdOpts.Start.File = "/nonexistent/file.sql" + + err := mockpge.ExecuteFileScript(context.Background(), cmdOpts) + assert.Error(t, err) + }) + + t.Run("File not found error - unknown extension", func(t *testing.T) { + cmdOpts := config.CmdOptions{} + cmdOpts.Start.File = "/nonexistent/file.txt" + + err := mockpge.ExecuteFileScript(context.Background(), cmdOpts) + assert.Error(t, err) + }) + + t.Run("Unknown file extension defaults to content detection", func(t *testing.T) { + // Create file with unknown extension containing SQL + unknownFile := filepath.Join(tmpDir, "test.unknown") + sqlContent := "SELECT 2;" + err := os.WriteFile(unknownFile, []byte(sqlContent), 0644) + assert.NoError(t, err) + + // Mock the SQL execution + mockPool.ExpectExec("SELECT 2;").WillReturnResult(pgxmock.NewResult("SELECT", 1)) + + cmdOpts 
:= config.CmdOptions{} + cmdOpts.Start.File = unknownFile + + err = mockpge.ExecuteFileScript(context.Background(), cmdOpts) + assert.NoError(t, err) + }) + + t.Run("Empty file", func(t *testing.T) { + // Create empty file + emptyFile := filepath.Join(tmpDir, "empty.txt") + err := os.WriteFile(emptyFile, []byte(""), 0644) + assert.NoError(t, err) + + // Mock empty SQL execution + mockPool.ExpectExec("").WillReturnResult(pgxmock.NewResult("SELECT", 0)) + + cmdOpts := config.CmdOptions{} + cmdOpts.Start.File = emptyFile + + err = mockpge.ExecuteFileScript(context.Background(), cmdOpts) + assert.NoError(t, err) + }) + + t.Run("YAML file with whitespace prefix", func(t *testing.T) { + // Create file with leading whitespace before chains: + whitespaceFile := filepath.Join(tmpDir, "whitespace") + yamlContent := ` + +chains: + - name: test_chain` + err := os.WriteFile(whitespaceFile, []byte(yamlContent), 0644) + assert.NoError(t, err) + + cmdOpts := config.CmdOptions{} + cmdOpts.Start.File = whitespaceFile + + err = mockpge.ExecuteFileScript(context.Background(), cmdOpts) + assert.NoError(t, err) + }) + + t.Run("YAML import with replace flag", func(t *testing.T) { + // Create temporary YAML file + yamlFile := filepath.Join(tmpDir, "test_replace.yaml") + yamlContent := `chains: + - name: test_chain_replace + tasks: + - name: test_task + command: SELECT 1` + err := os.WriteFile(yamlFile, []byte(yamlContent), 0644) + assert.NoError(t, err) + + cmdOpts := config.CmdOptions{} + cmdOpts.Start.File = yamlFile + cmdOpts.Start.Validate = false + cmdOpts.Start.Replace = true + + anyArgs := func(i int) []any { + args := make([]any, i) + for j := range i { + args[j] = pgxmock.AnyArg() + } + return args + } + + mockPool.ExpectExec("SELECT timetable\\.delete_job"). + WithArgs("test_chain_replace"). + WillReturnResult(pgxmock.NewResult("DELETE", 1)) + mockPool.ExpectQuery("SELECT EXISTS"). + WithArgs("test_chain_replace"). 
+ WillReturnRows(pgxmock.NewRows([]string{"exists"}).AddRow(false)) + mockPool.ExpectQuery("INSERT INTO timetable\\.chain"). + WithArgs(anyArgs(9)...). + WillReturnRows(pgxmock.NewRows([]string{"chain_id"}).AddRow(1)) + mockPool.ExpectQuery("INSERT INTO timetable\\.task"). + WithArgs(anyArgs(10)...). + WillReturnRows(pgxmock.NewRows([]string{"task_id"}).AddRow(1)) + + err = mockpge.ExecuteFileScript(context.Background(), cmdOpts) + assert.NoError(t, err) + }) + + t.Run("SQL file with multiple statements", func(t *testing.T) { + // Create temporary SQL file with multiple statements + sqlFile := filepath.Join(tmpDir, "multi.sql") + sqlContent := `SELECT 1; +SELECT 2; +INSERT INTO test VALUES (1);` + err := os.WriteFile(sqlFile, []byte(sqlContent), 0644) + assert.NoError(t, err) + + // Mock the SQL execution - use regex pattern to match the content + mockPool.ExpectExec(`SELECT 1;.*SELECT 2;.*INSERT INTO test VALUES \(1\);`).WillReturnResult(pgxmock.NewResult("SELECT", 1)) + + cmdOpts := config.CmdOptions{} + cmdOpts.Start.File = sqlFile + + err = mockpge.ExecuteFileScript(context.Background(), cmdOpts) + assert.NoError(t, err) + }) + + t.Run("Content detection with mixed content", func(t *testing.T) { + // Create file with content that doesn't start with "chains:" + mixedFile := filepath.Join(tmpDir, "mixed_content") + mixedContent := `# This is a comment +# chains: this is just a comment, not actual YAML +SELECT 1;` + err := os.WriteFile(mixedFile, []byte(mixedContent), 0644) + assert.NoError(t, err) + + // Mock the SQL execution since it doesn't start with "chains:" - use regex pattern + mockPool.ExpectExec(`# This is a comment.*# chains:.*SELECT 1;`).WillReturnResult(pgxmock.NewResult("SELECT", 1)) + + cmdOpts := config.CmdOptions{} + cmdOpts.Start.File = mixedFile + + err = mockpge.ExecuteFileScript(context.Background(), cmdOpts) + assert.NoError(t, err) + }) +} diff --git a/internal/pgengine/yaml.go b/internal/pgengine/yaml.go index 9c838959..21fec635 100644 --- 
a/internal/pgengine/yaml.go +++ b/internal/pgengine/yaml.go @@ -74,8 +74,7 @@ func (pge *PgEngine) LoadYamlChains(ctx context.Context, filePath string, replac func (pge *PgEngine) createChainFromYaml(ctx context.Context, yamlChain *YamlChain) (int64, error) { // Insert chain var chainID int64 - err := pge.ConfigDb.QueryRow(ctx, ` - INSERT INTO timetable.chain ( + err := pge.ConfigDb.QueryRow(ctx, `INSERT INTO timetable.chain ( chain_name, run_at, max_instances, timeout, live, self_destruct, exclusive_execution, client_name, on_error ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9) From dcf0a40c90902317457478a0075af6025172087e Mon Sep 17 00:00:00 2001 From: Pavlo Golub Date: Tue, 30 Sep 2025 11:34:11 +0200 Subject: [PATCH 5/7] fix tests and params --- docs/yaml-format.md | 4 +- docs/yaml-usage-guide.md | 3 +- internal/pgengine/bootstrap.go | 16 +- internal/pgengine/bootstrap_test.go | 125 +++------------ internal/pgengine/yaml.go | 56 ++----- internal/pgengine/yaml_test.go | 231 ++++++++++++++++------------ 6 files changed, 171 insertions(+), 264 deletions(-) diff --git a/docs/yaml-format.md b/docs/yaml-format.md index abf076e7..d6c90354 100644 --- a/docs/yaml-format.md +++ b/docs/yaml-format.md @@ -60,7 +60,7 @@ chains: | `name` | `task_name` | TEXT | `null` | Task description | | `kind` | `kind` | ENUM | `'SQL'` | Command type (SQL/PROGRAM/BUILTIN) | | `command` | `command` | TEXT | **required** | Command to execute | -| `parameters` | via `timetable.parameter` | Array of JSONB | `null` | Array of parameter values, each causing separate task execution | +| `parameters` | via `timetable.parameter` | Array of any | `null` | Array of parameter values stored as individual JSONB rows with order_id | | `run_as` | `run_as` | TEXT | `null` | Role for SET ROLE | | `connect_string` | `database_connection` | TEXT | `null` | Connection string | | `ignore_error` | `ignore_error` | BOOLEAN | `false` | Continue on error | @@ -131,5 +131,5 @@ chains: 2. 
**Unique Names**: Chain names must be unique across the database 3. **Valid Cron**: Schedule must be valid cron format (5 fields) 4. **Valid Kind**: Task kind must be one of: SQL, PROGRAM, BUILTIN -5. **Parameter Types**: Parameters must be strings or numbers (converted to JSONB array) +5. **Parameter Types**: Parameters can be any JSON-compatible type (strings, numbers, booleans, arrays, objects) and are stored as individual JSONB values 6. **Timeout Values**: Must be non-negative integers (milliseconds) diff --git a/docs/yaml-usage-guide.md b/docs/yaml-usage-guide.md index d9c7f19c..e00d7468 100644 --- a/docs/yaml-usage-guide.md +++ b/docs/yaml-usage-guide.md @@ -87,7 +87,8 @@ Each task can have multiple parameter entries, with each entry causing a separat command: "Log" parameters: - "WARNING: Simple message" - - {"level": "WARNING", "details": "Object message"} + - level: "WARNING" + details: "Object message" # BUILTIN: SendMail task (complex object) - name: "mail-task" diff --git a/internal/pgengine/bootstrap.go b/internal/pgengine/bootstrap.go index f3141387..edd8625f 100644 --- a/internal/pgengine/bootstrap.go +++ b/internal/pgengine/bootstrap.go @@ -247,21 +247,7 @@ func (pge *PgEngine) ExecuteFileScript(ctx context.Context, cmdOpts config.CmdOp return pge.ExecuteCustomScripts(ctx, filePath) default: - // Try to detect content type for files without extension - content, err := os.ReadFile(filePath) - if err != nil { - pge.l.WithError(err).Error("cannot read file") - return err - } - - // Check if it looks like YAML (starts with "chains:" or contains YAML markers) - contentStr := strings.TrimSpace(string(content)) - if strings.HasPrefix(contentStr, "chains:") { - pge.l.WithField("file", filePath).Info("Detected YAML content, processing as YAML") - return pge.LoadYamlChains(ctx, filePath, false) - } - pge.l.WithField("file", filePath).Info("Processing as SQL script") - return pge.ExecuteCustomScripts(ctx, filePath) + return errors.New("unsupported file 
extension: " + fileExt) } } diff --git a/internal/pgengine/bootstrap_test.go b/internal/pgengine/bootstrap_test.go index 3cae9f46..fb6c6865 100644 --- a/internal/pgengine/bootstrap_test.go +++ b/internal/pgengine/bootstrap_test.go @@ -154,6 +154,14 @@ func TestExecuteFileScript(t *testing.T) { // Create temporary directory for test files tmpDir := t.TempDir() + anyArgs := func(i int) []any { + args := make([]any, i) + for j := range i { + args[j] = pgxmock.AnyArg() + } + return args + } + t.Run("SQL file execution", func(t *testing.T) { // Create temporary SQL file sqlFile := filepath.Join(tmpDir, "test.sql") @@ -240,6 +248,16 @@ func TestExecuteFileScript(t *testing.T) { cmdOpts.Start.Validate = false cmdOpts.Start.Replace = false + mockPool.ExpectQuery("SELECT EXISTS"). + WithArgs("test_chain"). + WillReturnRows(pgxmock.NewRows([]string{"exists"}).AddRow(false)) + mockPool.ExpectQuery("INSERT INTO timetable\\.chain"). + WithArgs(anyArgs(9)...). + WillReturnRows(pgxmock.NewRows([]string{"chain_id"}).AddRow(1)) + mockPool.ExpectQuery("INSERT INTO timetable\\.task"). + WithArgs(anyArgs(10)...). 
+ WillReturnRows(pgxmock.NewRows([]string{"task_id"}).AddRow(1)) + err = mockpge.ExecuteFileScript(context.Background(), cmdOpts) assert.NoError(t, err) }) @@ -263,7 +281,7 @@ func TestExecuteFileScript(t *testing.T) { assert.NoError(t, err) }) - t.Run("File without extension - YAML content", func(t *testing.T) { + t.Run("File without extension", func(t *testing.T) { // Create file without extension with YAML content noExtFile := filepath.Join(tmpDir, "test_no_ext") yamlContent := `chains: @@ -278,92 +296,17 @@ func TestExecuteFileScript(t *testing.T) { cmdOpts.Start.File = noExtFile err = mockpge.ExecuteFileScript(context.Background(), cmdOpts) - assert.NoError(t, err) - }) - - t.Run("File without extension - SQL content", func(t *testing.T) { - // Create file without extension with SQL content - noExtFile := filepath.Join(tmpDir, "test_no_ext_sql") - sqlContent := "SELECT 1;" - err := os.WriteFile(noExtFile, []byte(sqlContent), 0644) - assert.NoError(t, err) - - // Mock the SQL execution - mockPool.ExpectExec("SELECT 1;").WillReturnResult(pgxmock.NewResult("SELECT", 1)) - - cmdOpts := config.CmdOptions{} - cmdOpts.Start.File = noExtFile - - err = mockpge.ExecuteFileScript(context.Background(), cmdOpts) - assert.NoError(t, err) - }) - - t.Run("File not found error - SQL file", func(t *testing.T) { - cmdOpts := config.CmdOptions{} - cmdOpts.Start.File = "/nonexistent/file.sql" - - err := mockpge.ExecuteFileScript(context.Background(), cmdOpts) assert.Error(t, err) }) - t.Run("File not found error - unknown extension", func(t *testing.T) { + t.Run("File not found error", func(t *testing.T) { cmdOpts := config.CmdOptions{} - cmdOpts.Start.File = "/nonexistent/file.txt" + cmdOpts.Start.File = "/nonexistent/file.sql" err := mockpge.ExecuteFileScript(context.Background(), cmdOpts) assert.Error(t, err) }) - t.Run("Unknown file extension defaults to content detection", func(t *testing.T) { - // Create file with unknown extension containing SQL - unknownFile := 
filepath.Join(tmpDir, "test.unknown") - sqlContent := "SELECT 2;" - err := os.WriteFile(unknownFile, []byte(sqlContent), 0644) - assert.NoError(t, err) - - // Mock the SQL execution - mockPool.ExpectExec("SELECT 2;").WillReturnResult(pgxmock.NewResult("SELECT", 1)) - - cmdOpts := config.CmdOptions{} - cmdOpts.Start.File = unknownFile - - err = mockpge.ExecuteFileScript(context.Background(), cmdOpts) - assert.NoError(t, err) - }) - - t.Run("Empty file", func(t *testing.T) { - // Create empty file - emptyFile := filepath.Join(tmpDir, "empty.txt") - err := os.WriteFile(emptyFile, []byte(""), 0644) - assert.NoError(t, err) - - // Mock empty SQL execution - mockPool.ExpectExec("").WillReturnResult(pgxmock.NewResult("SELECT", 0)) - - cmdOpts := config.CmdOptions{} - cmdOpts.Start.File = emptyFile - - err = mockpge.ExecuteFileScript(context.Background(), cmdOpts) - assert.NoError(t, err) - }) - - t.Run("YAML file with whitespace prefix", func(t *testing.T) { - // Create file with leading whitespace before chains: - whitespaceFile := filepath.Join(tmpDir, "whitespace") - yamlContent := ` - -chains: - - name: test_chain` - err := os.WriteFile(whitespaceFile, []byte(yamlContent), 0644) - assert.NoError(t, err) - - cmdOpts := config.CmdOptions{} - cmdOpts.Start.File = whitespaceFile - - err = mockpge.ExecuteFileScript(context.Background(), cmdOpts) - assert.NoError(t, err) - }) - t.Run("YAML import with replace flag", func(t *testing.T) { // Create temporary YAML file yamlFile := filepath.Join(tmpDir, "test_replace.yaml") @@ -380,14 +323,6 @@ chains: cmdOpts.Start.Validate = false cmdOpts.Start.Replace = true - anyArgs := func(i int) []any { - args := make([]any, i) - for j := range i { - args[j] = pgxmock.AnyArg() - } - return args - } - mockPool.ExpectExec("SELECT timetable\\.delete_job"). WithArgs("test_chain_replace"). 
WillReturnResult(pgxmock.NewResult("DELETE", 1)) @@ -424,22 +359,4 @@ INSERT INTO test VALUES (1);` assert.NoError(t, err) }) - t.Run("Content detection with mixed content", func(t *testing.T) { - // Create file with content that doesn't start with "chains:" - mixedFile := filepath.Join(tmpDir, "mixed_content") - mixedContent := `# This is a comment -# chains: this is just a comment, not actual YAML -SELECT 1;` - err := os.WriteFile(mixedFile, []byte(mixedContent), 0644) - assert.NoError(t, err) - - // Mock the SQL execution since it doesn't start with "chains:" - use regex pattern - mockPool.ExpectExec(`# This is a comment.*# chains:.*SELECT 1;`).WillReturnResult(pgxmock.NewResult("SELECT", 1)) - - cmdOpts := config.CmdOptions{} - cmdOpts.Start.File = mixedFile - - err = mockpge.ExecuteFileScript(context.Background(), cmdOpts) - assert.NoError(t, err) - }) } diff --git a/internal/pgengine/yaml.go b/internal/pgengine/yaml.go index 21fec635..2ba1fc8b 100644 --- a/internal/pgengine/yaml.go +++ b/internal/pgengine/yaml.go @@ -2,9 +2,9 @@ package pgengine import ( "context" + "encoding/json" "fmt" "os" - "path/filepath" "strings" "gopkg.in/yaml.v3" @@ -119,15 +119,21 @@ func (pge *PgEngine) createChainFromYaml(ctx context.Context, yamlChain *YamlCha // Insert parameters if any if len(task.Parameters) > 0 { - params, err := task.ToSQLParameters() - if err != nil { - return 0, fmt.Errorf("failed to convert parameters: %w", err) - } - _, err = pge.ConfigDb.Exec(ctx, - "INSERT INTO timetable.parameter (task_id, order_id, value) VALUES ($1, 1, $2::jsonb)", - taskID, params) - if err != nil { - return 0, fmt.Errorf("failed to insert parameters: %w", err) + for paramIndex, param := range task.Parameters { + orderID := paramIndex + 1 + + // Convert parameter to JSON for JSONB storage + jsonValue, err := json.Marshal(param) + if err != nil { + return 0, fmt.Errorf("failed to marshal parameter %d to JSON: %w", orderID, err) + } + + _, err = pge.ConfigDb.Exec(ctx, + "INSERT INTO 
timetable.parameter (task_id, order_id, value) VALUES ($1, $2, $3::jsonb)", + taskID, orderID, string(jsonValue)) + if err != nil { + return 0, fmt.Errorf("failed to insert parameter %d: %w", orderID, err) + } } } } @@ -229,12 +235,6 @@ func ParseYamlFile(filePath string) (*YamlConfig, error) { return nil, fmt.Errorf("file not found: %s", filePath) } - // Check file extension - ext := strings.ToLower(filepath.Ext(filePath)) - if ext != ".yaml" && ext != ".yml" { - return nil, fmt.Errorf("file must have .yaml or .yml extension: %s", filePath) - } - // Read file data, err := os.ReadFile(filePath) if err != nil { @@ -259,28 +259,4 @@ func ParseYamlFile(filePath string) (*YamlConfig, error) { return &config, nil } -// ToSQLParameters converts YAML parameters to SQL-compatible format -func (t *YamlTask) ToSQLParameters() (string, error) { - if len(t.Parameters) == 0 { - return "", nil - } - - // Convert to JSON array format for PostgreSQL - params := make([]string, len(t.Parameters)) - for i, param := range t.Parameters { - switch v := param.(type) { - case string: - params[i] = fmt.Sprintf(`"%s"`, strings.ReplaceAll(v, `"`, `\"`)) - case int, int32, int64: - params[i] = fmt.Sprintf("%v", v) - case float32, float64: - params[i] = fmt.Sprintf("%v", v) - case bool: - params[i] = fmt.Sprintf("%t", v) - default: - params[i] = fmt.Sprintf(`"%v"`, v) - } - } - return fmt.Sprintf("[%s]", strings.Join(params, ", ")), nil -} diff --git a/internal/pgengine/yaml_test.go b/internal/pgengine/yaml_test.go index ad566376..ed34600e 100644 --- a/internal/pgengine/yaml_test.go +++ b/internal/pgengine/yaml_test.go @@ -5,6 +5,7 @@ import ( "os" "testing" + "github.com/jackc/pgx/v5" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -253,18 +254,6 @@ func TestParseYamlFile(t *testing.T) { assert.Contains(t, err.Error(), "file not found") }) - t.Run("Invalid file extension", func(t *testing.T) { - tmpfile, err := os.CreateTemp("", "test-*.txt") - require.NoError(t, 
err) - tmpfileName := tmpfile.Name() - defer os.Remove(tmpfileName) - tmpfile.Close() - - _, err = pgengine.ParseYamlFile(tmpfileName) - assert.Error(t, err) - assert.Contains(t, err.Error(), "file must have .yaml or .yml extension") - }) - t.Run("Invalid YAML syntax", func(t *testing.T) { invalidYaml := `chains: - name: "test" @@ -549,66 +538,141 @@ func TestYamlChainSetDefaults(t *testing.T) { }) } -func TestToSQLParameters(t *testing.T) { - t.Run("No parameters", func(t *testing.T) { - task := &pgengine.YamlTask{} - result, err := task.ToSQLParameters() - assert.NoError(t, err) - assert.Equal(t, "", result) - }) +func TestParameterStorageIntegration(t *testing.T) { + container, cleanup := testutils.SetupPostgresContainer(t) + defer cleanup() - t.Run("String parameters", func(t *testing.T) { - task := &pgengine.YamlTask{ - Parameters: []any{"hello", "world"}, - } - result, err := task.ToSQLParameters() - assert.NoError(t, err) - assert.Equal(t, `["hello", "world"]`, result) - }) + ctx := context.Background() + pge := container.Engine - t.Run("String with quotes", func(t *testing.T) { - task := &pgengine.YamlTask{ - Parameters: []any{`hello "quoted" world`}, - } - result, err := task.ToSQLParameters() - assert.NoError(t, err) - assert.Equal(t, `["hello \"quoted\" world"]`, result) - }) + t.Run("Parameters stored as separate rows with correct order_id", func(t *testing.T) { + yamlContent := `chains: + - name: "test-parameters" + schedule: "0 0 * * *" + tasks: + - name: "mixed-params" + kind: "SQL" + command: "SELECT $1, $2, $3, $4, $5" + parameters: + - "hello world" + - 42 + - 3.14 + - true + - ["array", "value"]` - t.Run("Integer parameters", func(t *testing.T) { - task := &pgengine.YamlTask{ - Parameters: []any{42, int32(100), int64(200)}, - } - result, err := task.ToSQLParameters() - assert.NoError(t, err) - assert.Equal(t, `[42, 100, 200]`, result) - }) + tempFile := createTempYamlFile(t, yamlContent) + defer removeTempFile(t, tempFile) - t.Run("Float 
parameters", func(t *testing.T) { - task := &pgengine.YamlTask{ - Parameters: []any{3.14, float32(2.71)}, + err := pge.LoadYamlChains(ctx, tempFile, false) + require.NoError(t, err) + + // Get the task ID + var taskID int64 + err = pge.ConfigDb.QueryRow(ctx, + "SELECT task_id FROM timetable.task t JOIN timetable.chain c ON t.chain_id = c.chain_id WHERE c.chain_name = $1", + "test-parameters").Scan(&taskID) + require.NoError(t, err) + + // Verify parameters are stored as separate rows + type paramRow struct { + OrderID int `db:"order_id"` + Value string `db:"value"` } - result, err := task.ToSQLParameters() - assert.NoError(t, err) - assert.Equal(t, `[3.14, 2.71]`, result) + + rows, err := pge.ConfigDb.Query(ctx, + "SELECT order_id, value::text FROM timetable.parameter WHERE task_id = $1 ORDER BY order_id", + taskID) + require.NoError(t, err) + + params, err := pgx.CollectRows(rows, pgx.RowToStructByName[paramRow]) + require.NoError(t, err) + + // Should have 5 parameters + assert.Equal(t, 5, len(params)) + + // Check each parameter + assert.Equal(t, 1, params[0].OrderID) + assert.Equal(t, `"hello world"`, params[0].Value) + + assert.Equal(t, 2, params[1].OrderID) + assert.Equal(t, `42`, params[1].Value) + + assert.Equal(t, 3, params[2].OrderID) + assert.Equal(t, `3.14`, params[2].Value) + + assert.Equal(t, 4, params[3].OrderID) + assert.Equal(t, `true`, params[3].Value) + + assert.Equal(t, 5, params[4].OrderID) + assert.Contains(t, params[4].Value, `["array", "value"]`) }) - t.Run("Boolean parameters", func(t *testing.T) { - task := &pgengine.YamlTask{ - Parameters: []any{true, false}, - } - result, err := task.ToSQLParameters() - assert.NoError(t, err) - assert.Equal(t, `[true, false]`, result) + t.Run("Object parameters stored as JSONB objects", func(t *testing.T) { + yamlContent := `chains: + - name: "test-object-params" + schedule: "0 0 * * *" + tasks: + - name: "object-param" + kind: "BUILTIN" + command: "Log" + parameters: + - {"level": "WARNING", "message": 
"test", "count": 123}` + + tempFile := createTempYamlFile(t, yamlContent) + defer removeTempFile(t, tempFile) + + err := pge.LoadYamlChains(ctx, tempFile, false) + require.NoError(t, err) + + // Get the task ID + var taskID int64 + err = pge.ConfigDb.QueryRow(ctx, + "SELECT task_id FROM timetable.task t JOIN timetable.chain c ON t.chain_id = c.chain_id WHERE c.chain_name = $1", + "test-object-params").Scan(&taskID) + require.NoError(t, err) + + // Verify object parameter + var value string + err = pge.ConfigDb.QueryRow(ctx, + "SELECT value::text FROM timetable.parameter WHERE task_id = $1 AND order_id = 1", + taskID).Scan(&value) + require.NoError(t, err) + + // Should be a valid JSON object + assert.Contains(t, value, `"level"`) + assert.Contains(t, value, `"WARNING"`) + assert.Contains(t, value, `"message"`) + assert.Contains(t, value, `"test"`) + assert.Contains(t, value, `"count"`) + assert.Contains(t, value, `123`) }) - t.Run("Mixed parameter types", func(t *testing.T) { - task := &pgengine.YamlTask{ - Parameters: []any{"text", 42, 3.14, true, nil}, - } - result, err := task.ToSQLParameters() - assert.NoError(t, err) - assert.Equal(t, `["text", 42, 3.14, true, ""]`, result) + t.Run("No parameters creates no parameter rows", func(t *testing.T) { + yamlContent := `chains: + - name: "test-no-params" + schedule: "0 0 * * *" + tasks: + - name: "no-param" + kind: "SQL" + command: "SELECT 1"` + + tempFile := createTempYamlFile(t, yamlContent) + defer removeTempFile(t, tempFile) + + err := pge.LoadYamlChains(ctx, tempFile, false) + require.NoError(t, err) + + // Get the task ID + var count int + err = pge.ConfigDb.QueryRow(ctx, ` +SELECT COUNT(*) +FROM timetable.task t + JOIN timetable.chain c ON t.chain_id = c.chain_id + JOIN timetable.parameter p ON t.task_id = p.task_id + WHERE c.chain_name = $1`, + "test-no-params").Scan(&count) + require.NoError(t, err) + assert.Equal(t, 0, count) }) } @@ -714,7 +778,7 @@ func TestLoadYamlChainsMultiTask(t *testing.T) { "SELECT 
COUNT(*) FROM timetable.parameter p JOIN timetable.task t ON p.task_id = t.task_id WHERE t.chain_id = $1", chainID).Scan(&paramCount) require.NoError(t, err) - assert.Equal(t, 2, paramCount) // 2 tasks with parameters + assert.Equal(t, 4, paramCount) // First task has 2 params, second task has 2 params = 4 total }) t.Run("Chain already exists without replace", func(t *testing.T) { @@ -855,43 +919,6 @@ func TestCreateChainFromYamlEdgeCases(t *testing.T) { }) } -func TestToSQLParametersErrorHandling(t *testing.T) { - t.Run("Error in parameter conversion", func(t *testing.T) { - task := &pgengine.YamlTask{ - Parameters: []any{ - make(chan int), // unsupported type that can't be converted - }, - } - result, err := task.ToSQLParameters() - assert.NoError(t, err) // Function doesn't actually return errors, just converts to string - assert.Contains(t, result, "0x") // channel representation - }) - - t.Run("Empty array parameters", func(t *testing.T) { - task := &pgengine.YamlTask{ - Parameters: []any{[]any{}}, - } - result, err := task.ToSQLParameters() - assert.NoError(t, err) - assert.Equal(t, `["[]"]`, result) - }) - - t.Run("Complex nested structures", func(t *testing.T) { - task := &pgengine.YamlTask{ - Parameters: []any{ - map[string]any{ - "nested": map[string]any{ - "deep": []any{1, 2, 3}, - }, - }, - }, - } - result, err := task.ToSQLParameters() - assert.NoError(t, err) - assert.Contains(t, result, "nested") - }) -} - func TestCreateChainFromYamlErrorHandling(t *testing.T) { container, cleanup := testutils.SetupPostgresContainer(t) defer cleanup() @@ -1055,7 +1082,7 @@ func TestCreateChainFromYamlErrorHandling(t *testing.T) { WHERE t.chain_id = $1`, chainID).Scan(&paramCount) require.NoError(t, err) - assert.Equal(t, 3, paramCount) // 3 tasks with parameters + assert.Equal(t, 5, paramCount) // sql-task: 2 params, program-task: 2 params, builtin-task: 1 param = 5 total }) } From 963cac6b08e2b4a0052b339cd3be36778449d431 Mon Sep 17 00:00:00 2001 From: Pavlo Golub Date: 
Tue, 30 Sep 2025 12:37:45 +0200 Subject: [PATCH 6/7] add coverage --- internal/pgengine/bootstrap_test.go | 97 ++++++++++++++++++++++++++--- internal/pgengine/yaml.go | 12 ++-- internal/pgengine/yaml_test.go | 5 ++ 3 files changed, 99 insertions(+), 15 deletions(-) diff --git a/internal/pgengine/bootstrap_test.go b/internal/pgengine/bootstrap_test.go index fb6c6865..717bc5c6 100644 --- a/internal/pgengine/bootstrap_test.go +++ b/internal/pgengine/bootstrap_test.go @@ -17,6 +17,14 @@ import ( "github.com/stretchr/testify/assert" ) +func anyArgs(i int) []any { + args := make([]any, i) + for j := range i { + args[j] = pgxmock.AnyArg() + } + return args +} + func TestExecuteSchemaScripts(t *testing.T) { initmockdb(t) defer mockPool.Close() @@ -146,6 +154,87 @@ func TestTryLockClientName(t *testing.T) { }) } +func TestCreateChainFromYamlErrors(t *testing.T) { + initmockdb(t) + defer mockPool.Close() + mockpge := pgengine.NewDB(mockPool, "pgengine_unit_test") + + t.Run("Database error during chain creation", func(t *testing.T) { + mockPool.ExpectQuery(`INSERT INTO timetable.chain`). + WithArgs(anyArgs(9)...). + WillReturnError(fmt.Errorf("simulated DB error")) + _, err := mockpge.CreateChainFromYaml(ctx, &pgengine.YamlChain{}) + assert.Error(t, err) + assert.NoError(t, mockPool.ExpectationsWereMet()) + }) + + t.Run("Database error during task creation", func(t *testing.T) { + mockPool.ExpectQuery(`INSERT INTO timetable.chain`). + WithArgs(anyArgs(9)...). + WillReturnRows(pgxmock.NewRows([]string{"chain_id"}).AddRow(1)) + mockPool.ExpectQuery(`INSERT INTO timetable.task`). + WithArgs(anyArgs(10)...). 
+ WillReturnError(fmt.Errorf("simulated DB error on task")) + + _, err := mockpge.CreateChainFromYaml(ctx, &pgengine.YamlChain{ + Chain: pgengine.Chain{ChainName: "test-chain"}, + Schedule: "0 0 * * *", + Tasks: []pgengine.YamlTask{ + {ChainTask: pgengine.ChainTask{Command: "SELECT 1", Kind: "SQL"}}, + }, + }) + assert.Error(t, err) + assert.NoError(t, mockPool.ExpectationsWereMet()) + }) + + t.Run("Database error during parameter unmarshalling", func(t *testing.T) { + mockPool.ExpectQuery(`INSERT INTO timetable.chain`). + WithArgs(anyArgs(9)...). + WillReturnRows(pgxmock.NewRows([]string{"chain_id"}).AddRow(1)) + mockPool.ExpectQuery(`INSERT INTO timetable.task`). + WithArgs(anyArgs(10)...). + WillReturnRows(pgxmock.NewRows([]string{"task_id"}).AddRow(1)) + + _, err := mockpge.CreateChainFromYaml(ctx, &pgengine.YamlChain{ + Chain: pgengine.Chain{ChainName: "test-chain"}, + Schedule: "0 0 * * *", + Tasks: []pgengine.YamlTask{ + { + ChainTask: pgengine.ChainTask{Command: "SELECT 1", Kind: "SQL"}, + Parameters: []any{func() {}}, // functions cannot be marshalled to JSON + }, + }, + }) + assert.Error(t, err) + assert.NoError(t, mockPool.ExpectationsWereMet()) + }) + + t.Run("Database error during parameter creation", func(t *testing.T) { + mockPool.ExpectQuery(`INSERT INTO timetable.chain`). + WithArgs(anyArgs(9)...). + WillReturnRows(pgxmock.NewRows([]string{"chain_id"}).AddRow(1)) + mockPool.ExpectQuery(`INSERT INTO timetable.task`). + WithArgs(anyArgs(10)...). + WillReturnRows(pgxmock.NewRows([]string{"task_id"}).AddRow(1)) + mockPool.ExpectExec(`INSERT INTO timetable.parameter`). + WithArgs(anyArgs(3)...). 
+ WillReturnError(fmt.Errorf("simulated DB error on parameter")) + + _, err := mockpge.CreateChainFromYaml(ctx, &pgengine.YamlChain{ + Chain: pgengine.Chain{ChainName: "test-chain"}, + Schedule: "0 0 * * *", + Tasks: []pgengine.YamlTask{ + { + ChainTask: pgengine.ChainTask{Command: "SELECT 1", Kind: "SQL"}, + Parameters: []any{"foo"}, + }, + }, + }) + assert.Error(t, err) + assert.NoError(t, mockPool.ExpectationsWereMet()) + }) +} + func TestExecuteFileScript(t *testing.T) { initmockdb(t) defer mockPool.Close() @@ -154,14 +243,6 @@ func TestExecuteFileScript(t *testing.T) { // Create temporary directory for test files tmpDir := t.TempDir() - anyArgs := func(i int) []any { - args := make([]any, i) - for j := range i { - args[j] = pgxmock.AnyArg() - } - return args - } - t.Run("SQL file execution", func(t *testing.T) { // Create temporary SQL file sqlFile := filepath.Join(tmpDir, "test.sql") diff --git a/internal/pgengine/yaml.go b/internal/pgengine/yaml.go index 2ba1fc8b..92a63c73 100644 --- a/internal/pgengine/yaml.go +++ b/internal/pgengine/yaml.go @@ -59,7 +59,7 @@ func (pge *PgEngine) LoadYamlChains(ctx context.Context, filePath string, replac } // Multi-task chain - use direct SQL - chainID, err := pge.createChainFromYaml(ctx, &yamlChain) + chainID, err := pge.CreateChainFromYaml(ctx, &yamlChain) if err != nil { return fmt.Errorf("failed to create multi-task chain %s: %w", yamlChain.ChainName, err) } @@ -70,8 +70,8 @@ func (pge *PgEngine) LoadYamlChains(ctx context.Context, filePath string, replac return nil } -// createChainFromYaml creates a multi-task chain using direct SQL inserts -func (pge *PgEngine) createChainFromYaml(ctx context.Context, yamlChain *YamlChain) (int64, error) { +// CreateChainFromYaml creates a multi-task chain using direct SQL inserts +func (pge *PgEngine) CreateChainFromYaml(ctx context.Context, yamlChain *YamlChain) (int64, error) { // Insert chain var chainID int64 err := pge.ConfigDb.QueryRow(ctx, `INSERT INTO timetable.chain ( @@ 
-121,13 +121,13 @@ func (pge *PgEngine) createChainFromYaml(ctx context.Context, yamlChain *YamlCha if len(task.Parameters) > 0 { for paramIndex, param := range task.Parameters { orderID := paramIndex + 1 - + // Convert parameter to JSON for JSONB storage jsonValue, err := json.Marshal(param) if err != nil { return 0, fmt.Errorf("failed to marshal parameter %d to JSON: %w", orderID, err) } - + _, err = pge.ConfigDb.Exec(ctx, "INSERT INTO timetable.parameter (task_id, order_id, value) VALUES ($1, $2, $3::jsonb)", taskID, orderID, string(jsonValue)) @@ -258,5 +258,3 @@ func ParseYamlFile(filePath string) (*YamlConfig, error) { return &config, nil } - - diff --git a/internal/pgengine/yaml_test.go b/internal/pgengine/yaml_test.go index ed34600e..b98bf9b6 100644 --- a/internal/pgengine/yaml_test.go +++ b/internal/pgengine/yaml_test.go @@ -254,6 +254,11 @@ func TestParseYamlFile(t *testing.T) { assert.Contains(t, err.Error(), "file not found") }) + t.Run("File cannot be read", func(t *testing.T) { + _, err := pgengine.ParseYamlFile(".") + assert.Error(t, err) + }) + t.Run("Invalid YAML syntax", func(t *testing.T) { invalidYaml := `chains: - name: "test" From f545f91b254aedda34a3c4a4da36234cfc9a21b0 Mon Sep 17 00:00:00 2001 From: Pavlo Golub Date: Tue, 30 Sep 2025 12:59:54 +0200 Subject: [PATCH 7/7] rework some tests with pgxmock --- internal/pgengine/bootstrap_test.go | 81 ----- internal/pgengine/yaml_test.go | 516 +++++++++++++++------------- 2 files changed, 274 insertions(+), 323 deletions(-) diff --git a/internal/pgengine/bootstrap_test.go b/internal/pgengine/bootstrap_test.go index 717bc5c6..0be8f85f 100644 --- a/internal/pgengine/bootstrap_test.go +++ b/internal/pgengine/bootstrap_test.go @@ -154,87 +154,6 @@ func TestTryLockClientName(t *testing.T) { }) } -func TestCreateChainFromYamlErrors(t *testing.T) { - initmockdb(t) - defer mockPool.Close() - mockpge := pgengine.NewDB(mockPool, "pgengine_unit_test") - - t.Run("Database error during chain creation", func(t 
*testing.T) { - mockPool.ExpectQuery(`INSERT INTO timetable.chain`). - WithArgs(anyArgs(9)...). - WillReturnError(fmt.Errorf("simulated DB error")) - _, err := mockpge.CreateChainFromYaml(ctx, &pgengine.YamlChain{}) - assert.Error(t, err) - assert.NoError(t, mockPool.ExpectationsWereMet()) - }) - - t.Run("Database error during task creation", func(t *testing.T) { - mockPool.ExpectQuery(`INSERT INTO timetable.chain`). - WithArgs(anyArgs(9)...). - WillReturnRows(pgxmock.NewRows([]string{"chain_id"}).AddRow(1)) - mockPool.ExpectQuery(`INSERT INTO timetable.task`). - WithArgs(anyArgs(10)...). - WillReturnError(fmt.Errorf("simulated DB error on task")) - - _, err := mockpge.CreateChainFromYaml(ctx, &pgengine.YamlChain{ - Chain: pgengine.Chain{ChainName: "test-chain"}, - Schedule: "0 0 * * *", - Tasks: []pgengine.YamlTask{ - {ChainTask: pgengine.ChainTask{Command: "SELECT 1", Kind: "SQL"}}, - }, - }) - assert.Error(t, err) - assert.NoError(t, mockPool.ExpectationsWereMet()) - }) - - t.Run("Database error during parameter unmarshalling", func(t *testing.T) { - mockPool.ExpectQuery(`INSERT INTO timetable.chain`). - WithArgs(anyArgs(9)...). - WillReturnRows(pgxmock.NewRows([]string{"chain_id"}).AddRow(1)) - mockPool.ExpectQuery(`INSERT INTO timetable.task`). - WithArgs(anyArgs(10)...). - WillReturnRows(pgxmock.NewRows([]string{"task_id"}).AddRow(1)) - - _, err := mockpge.CreateChainFromYaml(ctx, &pgengine.YamlChain{ - Chain: pgengine.Chain{ChainName: "test-chain"}, - Schedule: "0 0 * * *", - Tasks: []pgengine.YamlTask{ - { - ChainTask: pgengine.ChainTask{Command: "SELECT 1", Kind: "SQL"}, - Parameters: []any{func() {}}, // functions cannot be marshalled to JSON - }, - }, - }) - assert.Error(t, err) - assert.NoError(t, mockPool.ExpectationsWereMet()) - }) - - t.Run("Database error during parameter creation", func(t *testing.T) { - mockPool.ExpectQuery(`INSERT INTO timetable.chain`). - WithArgs(anyArgs(9)...). 
- WillReturnRows(pgxmock.NewRows([]string{"chain_id"}).AddRow(1)) - mockPool.ExpectQuery(`INSERT INTO timetable.task`). - WithArgs(anyArgs(10)...). - WillReturnRows(pgxmock.NewRows([]string{"task_id"}).AddRow(1)) - mockPool.ExpectExec(`INSERT INTO timetable.parameter`). - WithArgs(anyArgs(3)...). - WillReturnError(fmt.Errorf("simulated DB error on parameter")) - - _, err := mockpge.CreateChainFromYaml(ctx, &pgengine.YamlChain{ - Chain: pgengine.Chain{ChainName: "test-chain"}, - Schedule: "0 0 * * *", - Tasks: []pgengine.YamlTask{ - { - ChainTask: pgengine.ChainTask{Command: "SELECT 1", Kind: "SQL"}, - Parameters: []any{"foo"}, - }, - }, - }) - assert.Error(t, err) - assert.NoError(t, mockPool.ExpectationsWereMet()) - }) -} - func TestExecuteFileScript(t *testing.T) { initmockdb(t) defer mockPool.Close() diff --git a/internal/pgengine/yaml_test.go b/internal/pgengine/yaml_test.go index b98bf9b6..38da4e4e 100644 --- a/internal/pgengine/yaml_test.go +++ b/internal/pgengine/yaml_test.go @@ -2,10 +2,12 @@ package pgengine_test import ( "context" + "fmt" "os" "testing" "github.com/jackc/pgx/v5" + "github.com/pashagolub/pgxmock/v4" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -684,11 +686,9 @@ FROM timetable.task t func TestNullString(t *testing.T) { // Note: nullString is not exported, so we test it indirectly through chain creation t.Run("Indirect test via chain creation", func(t *testing.T) { - container, cleanup := testutils.SetupPostgresContainer(t) - defer cleanup() - - ctx := context.Background() - pge := container.Engine + initmockdb(t) + defer mockPool.Close() + mockpge := pgengine.NewDB(mockPool, "pgengine_unit_test") yamlContent := `chains: - name: "test-null-strings" @@ -705,26 +705,27 @@ func TestNullString(t *testing.T) { tmpfile := createTempYamlFile(t, yamlContent) defer removeTempFile(t, tmpfile) - err := pge.LoadYamlChains(ctx, tmpfile, false) - require.NoError(t, err) - - // Verify NULL values in database - var 
clientName, onError any - err = pge.ConfigDb.QueryRow(ctx, - "SELECT client_name, on_error FROM timetable.chain WHERE chain_name = $1", - "test-null-strings").Scan(&clientName, &onError) + // Mock chain and task creation with empty strings converted to NULL + mockPool.ExpectQuery("SELECT EXISTS"). + WithArgs("test-null-strings"). + WillReturnRows(pgxmock.NewRows([]string{"exists"}).AddRow(false)) + mockPool.ExpectQuery(`INSERT INTO timetable\.chain`). + WithArgs(anyArgs(9)...). + WillReturnRows(pgxmock.NewRows([]string{"chain_id"}).AddRow(1)) + mockPool.ExpectQuery(`INSERT INTO timetable\.task`). + WithArgs(anyArgs(10)...). + WillReturnRows(pgxmock.NewRows([]string{"task_id"}).AddRow(1)) + + err := mockpge.LoadYamlChains(context.Background(), tmpfile, false) require.NoError(t, err) - assert.Nil(t, clientName) - assert.Nil(t, onError) + assert.NoError(t, mockPool.ExpectationsWereMet()) }) } func TestLoadYamlChainsMultiTask(t *testing.T) { - container, cleanup := testutils.SetupPostgresContainer(t) - defer cleanup() - - ctx := context.Background() - pge := container.Engine + initmockdb(t) + defer mockPool.Close() + mockpge := pgengine.NewDB(mockPool, "pgengine_unit_test") t.Run("Multi-task chain creation", func(t *testing.T) { yamlContent := `chains: @@ -758,32 +759,37 @@ func TestLoadYamlChainsMultiTask(t *testing.T) { tmpfile := createTempYamlFile(t, yamlContent) defer removeTempFile(t, tmpfile) - err := pge.LoadYamlChains(ctx, tmpfile, false) - require.NoError(t, err) - - // Verify chain was created - var chainID int64 - err = pge.ConfigDb.QueryRow(ctx, - "SELECT chain_id FROM timetable.chain WHERE chain_name = $1", - "multi-task-chain").Scan(&chainID) - require.NoError(t, err) - assert.Greater(t, chainID, int64(0)) - - // Verify tasks were created - var taskCount int - err = pge.ConfigDb.QueryRow(ctx, - "SELECT COUNT(*) FROM timetable.task WHERE chain_id = $1", - chainID).Scan(&taskCount) + // Mock chain creation + mockPool.ExpectQuery("SELECT EXISTS"). 
+ WithArgs("multi-task-chain"). + WillReturnRows(pgxmock.NewRows([]string{"exists"}).AddRow(false)) + mockPool.ExpectQuery(`INSERT INTO timetable\.chain`). + WithArgs(anyArgs(9)...). + WillReturnRows(pgxmock.NewRows([]string{"chain_id"}).AddRow(1)) + + // Mock first task creation + mockPool.ExpectQuery(`INSERT INTO timetable\.task`). + WithArgs(anyArgs(10)...). + WillReturnRows(pgxmock.NewRows([]string{"task_id"}).AddRow(1)) + // Mock first task parameters (2 parameters) + mockPool.ExpectExec(`INSERT INTO timetable\.parameter`). + WithArgs(anyArgs(3)...). + WillReturnResult(pgxmock.NewResult("INSERT", 1)). + Times(2) + + // Mock second task creation + mockPool.ExpectQuery(`INSERT INTO timetable\.task`). + WithArgs(anyArgs(10)...). + WillReturnRows(pgxmock.NewRows([]string{"task_id"}).AddRow(2)) + // Mock second task parameters (2 parameters) + mockPool.ExpectExec(`INSERT INTO timetable\.parameter`). + WithArgs(anyArgs(3)...). + WillReturnResult(pgxmock.NewResult("INSERT", 1)). + Times(2) + + err := mockpge.LoadYamlChains(context.Background(), tmpfile, false) require.NoError(t, err) - assert.Equal(t, 2, taskCount) - - // Verify task parameters were created - var paramCount int - err = pge.ConfigDb.QueryRow(ctx, - "SELECT COUNT(*) FROM timetable.parameter p JOIN timetable.task t ON p.task_id = t.task_id WHERE t.chain_id = $1", - chainID).Scan(&paramCount) - require.NoError(t, err) - assert.Equal(t, 4, paramCount) // First task has 2 params, second task has 2 params = 4 total + assert.NoError(t, mockPool.ExpectationsWereMet()) }) t.Run("Chain already exists without replace", func(t *testing.T) { @@ -797,14 +803,15 @@ func TestLoadYamlChainsMultiTask(t *testing.T) { tmpfile := createTempYamlFile(t, yamlContent) defer removeTempFile(t, tmpfile) - // First load should succeed - err := pge.LoadYamlChains(ctx, tmpfile, false) - require.NoError(t, err) + // Mock chain already exists + mockPool.ExpectQuery("SELECT EXISTS"). + WithArgs("existing-chain"). 
+ WillReturnRows(pgxmock.NewRows([]string{"exists"}).AddRow(true)) - // Second load without replace should fail - err = pge.LoadYamlChains(ctx, tmpfile, false) + err := mockpge.LoadYamlChains(context.Background(), tmpfile, false) assert.Error(t, err) assert.Contains(t, err.Error(), "already exists") + assert.NoError(t, mockPool.ExpectationsWereMet()) }) t.Run("Database error during chain creation", func(t *testing.T) { @@ -820,21 +827,19 @@ func TestLoadYamlChainsMultiTask(t *testing.T) { defer removeTempFile(t, tmpfile) // This should fail at YAML validation level due to invalid cron format - err := pge.LoadYamlChains(ctx, tmpfile, false) + err := mockpge.LoadYamlChains(context.Background(), tmpfile, false) assert.Error(t, err) assert.Contains(t, err.Error(), "failed to parse YAML file") }) } func TestLoadYamlChainsErrorCases(t *testing.T) { - container, cleanup := testutils.SetupPostgresContainer(t) - defer cleanup() - - ctx := context.Background() - pge := container.Engine + initmockdb(t) + defer mockPool.Close() + mockpge := pgengine.NewDB(mockPool, "pgengine_unit_test") t.Run("Invalid YAML file", func(t *testing.T) { - err := pge.LoadYamlChains(ctx, "/non/existent/file.yaml", false) + err := mockpge.LoadYamlChains(context.Background(), "/non/existent/file.yaml", false) assert.Error(t, err) assert.Contains(t, err.Error(), "failed to parse YAML file") }) @@ -849,18 +854,16 @@ func TestLoadYamlChainsErrorCases(t *testing.T) { tmpfile := createTempYamlFile(t, invalidYaml) defer removeTempFile(t, tmpfile) - err := pge.LoadYamlChains(ctx, tmpfile, false) + err := mockpge.LoadYamlChains(context.Background(), tmpfile, false) assert.Error(t, err) assert.Contains(t, err.Error(), "failed to parse YAML file") }) } func TestCreateChainFromYamlEdgeCases(t *testing.T) { - container, cleanup := testutils.SetupPostgresContainer(t) - defer cleanup() - - ctx := context.Background() - pge := container.Engine + initmockdb(t) + defer mockPool.Close() + mockpge := 
pgengine.NewDB(mockPool, "pgengine_unit_test") t.Run("Task with no parameters", func(t *testing.T) { yamlContent := `chains: @@ -878,19 +881,26 @@ func TestCreateChainFromYamlEdgeCases(t *testing.T) { tmpfile := createTempYamlFile(t, yamlContent) defer removeTempFile(t, tmpfile) - err := pge.LoadYamlChains(ctx, tmpfile, false) + // Mock chain creation and tasks without parameters + mockPool.ExpectQuery("SELECT EXISTS"). + WithArgs("no-params-chain"). + WillReturnRows(pgxmock.NewRows([]string{"exists"}).AddRow(false)) + mockPool.ExpectQuery(`INSERT INTO timetable\.chain`). + WithArgs(anyArgs(9)...). + WillReturnRows(pgxmock.NewRows([]string{"chain_id"}).AddRow(1)) + + // Mock first task creation (no parameters) + mockPool.ExpectQuery(`INSERT INTO timetable\.task`). + WithArgs(anyArgs(10)...). + WillReturnRows(pgxmock.NewRows([]string{"task_id"}).AddRow(1)) + // Mock second task creation (empty parameters) + mockPool.ExpectQuery(`INSERT INTO timetable\.task`). + WithArgs(anyArgs(10)...). + WillReturnRows(pgxmock.NewRows([]string{"task_id"}).AddRow(2)) + + err := mockpge.LoadYamlChains(context.Background(), tmpfile, false) require.NoError(t, err) - - // Verify no parameters were inserted - var paramCount int - err = pge.ConfigDb.QueryRow(ctx, - `SELECT COUNT(*) FROM timetable.parameter p - JOIN timetable.task t ON p.task_id = t.task_id - JOIN timetable.chain c ON t.chain_id = c.chain_id - WHERE c.chain_name = $1`, - "no-params-chain").Scan(&paramCount) - require.NoError(t, err) - assert.Equal(t, 0, paramCount) + assert.NoError(t, mockPool.ExpectationsWereMet()) }) t.Run("Complex parameter types", func(t *testing.T) { @@ -907,29 +917,31 @@ func TestCreateChainFromYamlEdgeCases(t *testing.T) { tmpfile := createTempYamlFile(t, yamlContent) defer removeTempFile(t, tmpfile) - err := pge.LoadYamlChains(ctx, tmpfile, false) + // Mock chain and task creation + mockPool.ExpectQuery("SELECT EXISTS"). + WithArgs("complex-params-chain"). 
+ WillReturnRows(pgxmock.NewRows([]string{"exists"}).AddRow(false)) + mockPool.ExpectQuery(`INSERT INTO timetable\.chain`). + WithArgs(anyArgs(9)...). + WillReturnRows(pgxmock.NewRows([]string{"chain_id"}).AddRow(1)) + mockPool.ExpectQuery(`INSERT INTO timetable\.task`). + WithArgs(anyArgs(10)...). + WillReturnRows(pgxmock.NewRows([]string{"task_id"}).AddRow(1)) + // Mock parameter insertion + mockPool.ExpectExec(`INSERT INTO timetable\.parameter`). + WithArgs(anyArgs(3)...). + WillReturnResult(pgxmock.NewResult("INSERT", 1)) + + err := mockpge.LoadYamlChains(context.Background(), tmpfile, false) require.NoError(t, err) - - // Verify parameter was stored - var paramValue string - err = pge.ConfigDb.QueryRow(ctx, - `SELECT p.value FROM timetable.parameter p - JOIN timetable.task t ON p.task_id = t.task_id - JOIN timetable.chain c ON t.chain_id = c.chain_id - WHERE c.chain_name = $1`, - "complex-params-chain").Scan(&paramValue) - require.NoError(t, err) - assert.Contains(t, paramValue, "key") - assert.Contains(t, paramValue, "value") + assert.NoError(t, mockPool.ExpectationsWereMet()) }) } func TestCreateChainFromYamlErrorHandling(t *testing.T) { - container, cleanup := testutils.SetupPostgresContainer(t) - defer cleanup() - - ctx := context.Background() - pge := container.Engine + initmockdb(t) + defer mockPool.Close() + mockpge := pgengine.NewDB(mockPool, "pgengine_unit_test") t.Run("Multi-task chain with invalid parameter conversion", func(t *testing.T) { yamlContent := `chains: @@ -947,23 +959,31 @@ func TestCreateChainFromYamlErrorHandling(t *testing.T) { tmpfile := createTempYamlFile(t, yamlContent) defer removeTempFile(t, tmpfile) - err := pge.LoadYamlChains(ctx, tmpfile, false) + // Mock chain creation + mockPool.ExpectQuery("SELECT EXISTS"). + WithArgs("param-error-chain"). + WillReturnRows(pgxmock.NewRows([]string{"exists"}).AddRow(false)) + mockPool.ExpectQuery(`INSERT INTO timetable\.chain`). + WithArgs(anyArgs(9)...). 
+ WillReturnRows(pgxmock.NewRows([]string{"chain_id"}).AddRow(1)) + + // Mock first task with complex parameter + mockPool.ExpectQuery(`INSERT INTO timetable\.task`). + WithArgs(anyArgs(10)...). + WillReturnRows(pgxmock.NewRows([]string{"task_id"}).AddRow(1)) + mockPool.ExpectExec(`INSERT INTO timetable\.parameter`). + WithArgs(anyArgs(3)...). + WillReturnResult(pgxmock.NewResult("INSERT", 1)) + + // Mock second task (no parameters) + mockPool.ExpectQuery(`INSERT INTO timetable\.task`). + WithArgs(anyArgs(10)...). + WillReturnRows(pgxmock.NewRows([]string{"task_id"}).AddRow(2)) + + err := mockpge.LoadYamlChains(context.Background(), tmpfile, false) // Should succeed even with complex parameters require.NoError(t, err) - - // Verify chain and tasks were created - var chainID int64 - err = pge.ConfigDb.QueryRow(ctx, - "SELECT chain_id FROM timetable.chain WHERE chain_name = $1", - "param-error-chain").Scan(&chainID) - require.NoError(t, err) - - var taskCount int - err = pge.ConfigDb.QueryRow(ctx, - "SELECT COUNT(*) FROM timetable.task WHERE chain_id = $1", - chainID).Scan(&taskCount) - require.NoError(t, err) - assert.Equal(t, 2, taskCount) + assert.NoError(t, mockPool.ExpectationsWereMet()) }) t.Run("Multi-task chain with various field types", func(t *testing.T) { @@ -1005,104 +1025,51 @@ func TestCreateChainFromYamlErrorHandling(t *testing.T) { tmpfile := createTempYamlFile(t, yamlContent) defer removeTempFile(t, tmpfile) - err := pge.LoadYamlChains(ctx, tmpfile, false) - require.NoError(t, err) - - // Verify all chain properties - var chainID int64 - var schedule, clientName, onError string - var live, selfDestruct, exclusive bool - var maxInstances, timeout int - err = pge.ConfigDb.QueryRow(ctx, - `SELECT chain_id, run_at, client_name, on_error, live, - self_destruct, exclusive_execution, max_instances, timeout - FROM timetable.chain WHERE chain_name = $1`, - "comprehensive-multi-task").Scan( - &chainID, &schedule, &clientName, &onError, &live, - &selfDestruct, 
&exclusive, &maxInstances, &timeout) - require.NoError(t, err) - - assert.Equal(t, "@every 1h", schedule) - assert.Equal(t, "test-client-multi", clientName) - assert.Equal(t, "IGNORE", onError) - assert.False(t, live) - assert.True(t, selfDestruct) - assert.False(t, exclusive) - assert.Equal(t, 3, maxInstances) - assert.Equal(t, 600, timeout) - - // Verify all tasks were created with correct properties - rows, err := pge.ConfigDb.Query(ctx, - `SELECT task_name, kind, command, ignore_error, autonomous, - timeout, run_as, database_connection - FROM timetable.task WHERE chain_id = $1 ORDER BY task_order`, - chainID) - require.NoError(t, err) - defer rows.Close() - - expectedTasks := []struct { - name string - kind string - command string - ignoreErr bool - auto bool - timeout int - runAs *string - dbConn *string - }{ - {"sql-task", "SQL", "SELECT $1, $2", true, false, 120, stringPtr("test_user"), stringPtr("dbname=test")}, - {"program-task", "PROGRAM", "echo", false, true, 60, nil, nil}, - {"builtin-task", "BUILTIN", "Sleep", false, false, 10, nil, nil}, - } - - taskIdx := 0 - for rows.Next() { - var name, kind, command string - var ignoreErr, auto bool - var timeout int - var runAs, dbConn *string - - err := rows.Scan(&name, &kind, &command, &ignoreErr, &auto, &timeout, &runAs, &dbConn) - require.NoError(t, err) - - expected := expectedTasks[taskIdx] - assert.Equal(t, expected.name, name) - assert.Equal(t, expected.kind, kind) - assert.Equal(t, expected.command, command) - assert.Equal(t, expected.ignoreErr, ignoreErr) - assert.Equal(t, expected.auto, auto) - assert.Equal(t, expected.timeout, timeout) - assert.Equal(t, expected.runAs, runAs) - assert.Equal(t, expected.dbConn, dbConn) - - taskIdx++ - } - assert.Equal(t, 3, taskIdx) - - // Verify parameters were stored correctly - var paramCount int - err = pge.ConfigDb.QueryRow(ctx, - `SELECT COUNT(*) FROM timetable.parameter p - JOIN timetable.task t ON p.task_id = t.task_id - WHERE t.chain_id = $1`, - 
chainID).Scan(&paramCount) + // Mock comprehensive chain creation + mockPool.ExpectQuery("SELECT EXISTS"). + WithArgs("comprehensive-multi-task"). + WillReturnRows(pgxmock.NewRows([]string{"exists"}).AddRow(false)) + mockPool.ExpectQuery(`INSERT INTO timetable\.chain`). + WithArgs(anyArgs(9)...). + WillReturnRows(pgxmock.NewRows([]string{"chain_id"}).AddRow(1)) + + // Mock sql-task creation with 2 parameters + mockPool.ExpectQuery(`INSERT INTO timetable\.task`). + WithArgs(anyArgs(10)...). + WillReturnRows(pgxmock.NewRows([]string{"task_id"}).AddRow(1)) + mockPool.ExpectExec(`INSERT INTO timetable\.parameter`). + WithArgs(anyArgs(3)...). + WillReturnResult(pgxmock.NewResult("INSERT", 1)). + Times(2) + + // Mock program-task creation with 2 parameters + mockPool.ExpectQuery(`INSERT INTO timetable\.task`). + WithArgs(anyArgs(10)...). + WillReturnRows(pgxmock.NewRows([]string{"task_id"}).AddRow(2)) + mockPool.ExpectExec(`INSERT INTO timetable\.parameter`). + WithArgs(anyArgs(3)...). + WillReturnResult(pgxmock.NewResult("INSERT", 1)). + Times(2) + + // Mock builtin-task creation with 1 parameter + mockPool.ExpectQuery(`INSERT INTO timetable\.task`). + WithArgs(anyArgs(10)...). + WillReturnRows(pgxmock.NewRows([]string{"task_id"}).AddRow(3)) + mockPool.ExpectExec(`INSERT INTO timetable\.parameter`). + WithArgs(anyArgs(3)...). 
+ WillReturnResult(pgxmock.NewResult("INSERT", 1)) + + err := mockpge.LoadYamlChains(context.Background(), tmpfile, false) require.NoError(t, err) - assert.Equal(t, 5, paramCount) // sql-task: 2 params, program-task: 2 params, builtin-task: 1 param = 5 total + assert.NoError(t, mockPool.ExpectationsWereMet()) }) } -// Helper function to create string pointer -func stringPtr(s string) *string { - return &s -} - func TestNullStringFunction(t *testing.T) { // Testing nullString indirectly through database operations - container, cleanup := testutils.SetupPostgresContainer(t) - defer cleanup() - - ctx := context.Background() - pge := container.Engine + initmockdb(t) + defer mockPool.Close() + mockpge := pgengine.NewDB(mockPool, "pgengine_unit_test") t.Run("All null fields", func(t *testing.T) { yamlContent := `chains: @@ -1120,28 +1087,21 @@ func TestNullStringFunction(t *testing.T) { tmpfile := createTempYamlFile(t, yamlContent) defer removeTempFile(t, tmpfile) - err := pge.LoadYamlChains(ctx, tmpfile, false) - require.NoError(t, err) - - // Verify NULL values in chain table - var clientName, onError any - err = pge.ConfigDb.QueryRow(ctx, - "SELECT client_name, on_error FROM timetable.chain WHERE chain_name = $1", - "all-nulls-chain").Scan(&clientName, &onError) - require.NoError(t, err) - assert.Nil(t, clientName) - assert.Nil(t, onError) - - // Verify NULL values in task table - var runAs, dbConn any - err = pge.ConfigDb.QueryRow(ctx, - `SELECT run_as, database_connection FROM timetable.task t - JOIN timetable.chain c ON t.chain_id = c.chain_id - WHERE c.chain_name = $1`, - "all-nulls-chain").Scan(&runAs, &dbConn) + // Mock chain creation with NULL fields + mockPool.ExpectQuery("SELECT EXISTS"). + WithArgs("all-nulls-chain"). + WillReturnRows(pgxmock.NewRows([]string{"exists"}).AddRow(false)) + mockPool.ExpectQuery(`INSERT INTO timetable\.chain`). + WithArgs(anyArgs(9)...). 
+ WillReturnRows(pgxmock.NewRows([]string{"chain_id"}).AddRow(1)) + // Mock task creation with NULL fields + mockPool.ExpectQuery(`INSERT INTO timetable\.task`). + WithArgs(anyArgs(10)...). + WillReturnRows(pgxmock.NewRows([]string{"task_id"}).AddRow(1)) + + err := mockpge.LoadYamlChains(context.Background(), tmpfile, false) require.NoError(t, err) - assert.Nil(t, runAs) - assert.Nil(t, dbConn) + assert.NoError(t, mockPool.ExpectationsWereMet()) }) t.Run("Mixed null and non-null fields", func(t *testing.T) { @@ -1160,29 +1120,101 @@ func TestNullStringFunction(t *testing.T) { tmpfile := createTempYamlFile(t, yamlContent) defer removeTempFile(t, tmpfile) - err := pge.LoadYamlChains(ctx, tmpfile, false) + // Mock chain creation with mixed NULL/non-NULL fields + mockPool.ExpectQuery("SELECT EXISTS"). + WithArgs("mixed-nulls-chain"). + WillReturnRows(pgxmock.NewRows([]string{"exists"}).AddRow(false)) + mockPool.ExpectQuery(`INSERT INTO timetable\.chain`). + WithArgs(anyArgs(9)...). + WillReturnRows(pgxmock.NewRows([]string{"chain_id"}).AddRow(1)) + // Mock task creation with mixed NULL/non-NULL fields + mockPool.ExpectQuery(`INSERT INTO timetable\.task`). + WithArgs(anyArgs(10)...). 
+ WillReturnRows(pgxmock.NewRows([]string{"task_id"}).AddRow(1)) + + err := mockpge.LoadYamlChains(context.Background(), tmpfile, false) require.NoError(t, err) + assert.NoError(t, mockPool.ExpectationsWereMet()) + }) +} - // Verify mixed values in chain table - var clientName, onError any - err = pge.ConfigDb.QueryRow(ctx, - "SELECT client_name, on_error FROM timetable.chain WHERE chain_name = $1", - "mixed-nulls-chain").Scan(&clientName, &onError) - require.NoError(t, err) - require.NotNil(t, clientName) - assert.Equal(t, "real-client", clientName) - assert.Nil(t, onError) +func TestCreateChainFromYamlErrors(t *testing.T) { + initmockdb(t) + defer mockPool.Close() + mockpge := pgengine.NewDB(mockPool, "pgengine_unit_test") - // Verify mixed values in task table - var runAs, dbConn any - err = pge.ConfigDb.QueryRow(ctx, - `SELECT run_as, database_connection FROM timetable.task t - JOIN timetable.chain c ON t.chain_id = c.chain_id - WHERE c.chain_name = $1`, - "mixed-nulls-chain").Scan(&runAs, &dbConn) - require.NoError(t, err) - require.NotNil(t, runAs) - assert.Equal(t, "real-user", runAs) - assert.Nil(t, dbConn) + t.Run("Database error during chain creation", func(t *testing.T) { + mockPool.ExpectQuery(`INSERT INTO timetable.chain`). + WithArgs(anyArgs(9)...). + WillReturnError(fmt.Errorf("simulated DB error")) + _, err := mockpge.CreateChainFromYaml(ctx, &pgengine.YamlChain{}) + assert.Error(t, err) + assert.NoError(t, mockPool.ExpectationsWereMet()) + }) + + t.Run("Database error during task creation", func(t *testing.T) { + mockPool.ExpectQuery(`INSERT INTO timetable.chain`). + WithArgs(anyArgs(9)...). + WillReturnRows(pgxmock.NewRows([]string{"chain_id"}).AddRow(1)) + mockPool.ExpectQuery(`INSERT INTO timetable.task`). + WithArgs(anyArgs(10)...). 
+ WillReturnError(fmt.Errorf("simulated DB error on task")) + + _, err := mockpge.CreateChainFromYaml(ctx, &pgengine.YamlChain{ + Chain: pgengine.Chain{ChainName: "test-chain"}, + Schedule: "0 0 * * *", + Tasks: []pgengine.YamlTask{ + {ChainTask: pgengine.ChainTask{Command: "SELECT 1", Kind: "SQL"}}, + }, + }) + assert.Error(t, err) + assert.NoError(t, mockPool.ExpectationsWereMet()) + }) + + t.Run("Database error during parameter unmarshalling", func(t *testing.T) { + mockPool.ExpectQuery(`INSERT INTO timetable.chain`). + WithArgs(anyArgs(9)...). + WillReturnRows(pgxmock.NewRows([]string{"chain_id"}).AddRow(1)) + mockPool.ExpectQuery(`INSERT INTO timetable.task`). + WithArgs(anyArgs(10)...). + WillReturnRows(pgxmock.NewRows([]string{"task_id"}).AddRow(1)) + + _, err := mockpge.CreateChainFromYaml(ctx, &pgengine.YamlChain{ + Chain: pgengine.Chain{ChainName: "test-chain"}, + Schedule: "0 0 * * *", + Tasks: []pgengine.YamlTask{ + { + ChainTask: pgengine.ChainTask{Command: "SELECT 1", Kind: "SQL"}, + Parameters: []any{func() {}}, // functions cannot be marshalled to JSON + }, + }, + }) + assert.Error(t, err) + assert.NoError(t, mockPool.ExpectationsWereMet()) + }) + + t.Run("Database error during parameter creation", func(t *testing.T) { + mockPool.ExpectQuery(`INSERT INTO timetable.chain`). + WithArgs(anyArgs(9)...). + WillReturnRows(pgxmock.NewRows([]string{"chain_id"}).AddRow(1)) + mockPool.ExpectQuery(`INSERT INTO timetable.task`). + WithArgs(anyArgs(10)...). + WillReturnRows(pgxmock.NewRows([]string{"task_id"}).AddRow(1)) + mockPool.ExpectExec(`INSERT INTO timetable.parameter`). + WithArgs(anyArgs(3)...). 
+ WillReturnError(fmt.Errorf("simulated DB error on parameter")) + + _, err := mockpge.CreateChainFromYaml(ctx, &pgengine.YamlChain{ + Chain: pgengine.Chain{ChainName: "test-chain"}, + Schedule: "0 0 * * *", + Tasks: []pgengine.YamlTask{ + { + ChainTask: pgengine.ChainTask{Command: "SELECT 1", Kind: "SQL"}, + Parameters: []any{"foo"}, + }, + }, + }) + assert.Error(t, err) + assert.NoError(t, mockPool.ExpectationsWereMet()) }) }