diff --git a/.github/typos.toml b/.github/typos.toml index fc4c943..aac65b6 100644 --- a/.github/typos.toml +++ b/.github/typos.toml @@ -1,2 +1,2 @@ [files] -extend-exclude = ["pkg/tui/view.go", "pkg/tui/logs.go"] +extend-exclude = ["pkg/tui/view.go", "pkg/tui/logs.go", "pkg/seeddata/discovery.go"] diff --git a/README.md b/README.md index 74028f6..c9d8186 100644 --- a/README.md +++ b/README.md @@ -226,3 +226,36 @@ models: ``` See [`.cbt-overrides.example.yaml`](.cbt-overrides.example.yaml) for more examples. + +## xatu-cbt Test Data + +Generate seed data parquet files for xatu-cbt tests: + +```bash +# Interactive mode +xcli lab xatu-cbt generate-seed-data + +# Scripted mode +xcli lab xatu-cbt generate-seed-data \ + --model consensus_engine_api_new_payload \ + --network mainnet \ + --spec fusaka \ + --range-column slot \ + --from 1000000 \ + --to 1001000 \ + --filter "status = VALID" \ + --upload +``` + +Requires hybrid mode (`xcli lab mode hybrid`) for external ClickHouse access. + +### S3 Upload + +Set R2/S3 credentials for upload: + +```bash +export AWS_ACCESS_KEY_ID="your-access-key" +export AWS_SECRET_ACCESS_KEY="your-secret-key" +``` + +Defaults to ethpandaops R2 bucket. Override with `S3_ENDPOINT` and `S3_BUCKET` env vars. diff --git a/go.mod b/go.mod index a8c57fb..4d4f45b 100644 --- a/go.mod +++ b/go.mod @@ -6,6 +6,10 @@ require ( dario.cat/mergo v1.0.2 github.com/ClickHouse/clickhouse-go/v2 v2.41.0 github.com/acarl005/stripansi v0.0.0-20180116102854-5a71ef0e047d + github.com/aws/aws-sdk-go-v2 v1.41.0 + github.com/aws/aws-sdk-go-v2/config v1.32.5 + github.com/aws/aws-sdk-go-v2/service/s3 v1.93.2 + github.com/aws/smithy-go v1.24.0 github.com/charmbracelet/bubbletea v1.3.10 github.com/charmbracelet/lipgloss v1.1.0 github.com/docker/docker v28.5.2+incompatible @@ -27,6 +31,21 @@ require ( github.com/ClickHouse/ch-go v0.69.0 // indirect github.com/Microsoft/go-winio v0.6.2 // indirect github.com/andybalholm/brotli v1.2.0 // indirect + github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.4 // indirect + github.com/aws/aws-sdk-go-v2/credentials v1.19.5 // indirect + github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.16 // indirect + github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.16 // indirect + github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.16 // indirect + github.com/aws/aws-sdk-go-v2/internal/ini v1.8.4 // indirect + github.com/aws/aws-sdk-go-v2/internal/v4a v1.4.16 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.4 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.9.7 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.16 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.19.16 // indirect + github.com/aws/aws-sdk-go-v2/service/signin v1.0.4 // indirect + github.com/aws/aws-sdk-go-v2/service/sso v1.30.7 // indirect + github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.12 // indirect + github.com/aws/aws-sdk-go-v2/service/sts v1.41.5 // indirect github.com/aymanbagabas/go-osc52/v2 v2.0.1 // indirect github.com/charmbracelet/colorprofile v0.2.3-0.20250311203215-f60798e515dc // indirect github.com/charmbracelet/x/ansi v0.10.1 // indirect diff --git a/go.sum b/go.sum index ea33af8..d395fac 100644 --- a/go.sum +++ b/go.sum @@ -32,6 +32,44 @@ github.com/acarl005/stripansi v0.0.0-20180116102854-5a71ef0e047d/go.mod h1:asat6 github.com/andybalholm/brotli v1.2.0 h1:ukwgCxwYrmACq68yiUqwIWnGY0cTPox/M94sVwToPjQ= github.com/andybalholm/brotli v1.2.0/go.mod h1:rzTDkvFWvIrjDXZHkuS16NPggd91W3kUSvPlQ1pLaKY= github.com/atomicgo/cursor v0.0.1/go.mod h1:cBON2QmmrysudxNBFthvMtN32r3jxVRIvzkUiF/RuIk= +github.com/aws/aws-sdk-go-v2 v1.41.0 h1:tNvqh1s+v0vFYdA1xq0aOJH+Y5cRyZ5upu6roPgPKd4= +github.com/aws/aws-sdk-go-v2 v1.41.0/go.mod h1:MayyLB8y+buD9hZqkCW3kX1AKq07Y5pXxtgB+rRFhz0= +github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.4 h1:489krEF9xIGkOaaX3CE/Be2uWjiXrkCH6gUX+bZA/BU= +github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.4/go.mod h1:IOAPF6oT9KCsceNTvvYMNHy0+kMF8akOjeDvPENWxp4= +github.com/aws/aws-sdk-go-v2/config v1.32.5 h1:pz3duhAfUgnxbtVhIK39PGF/AHYyrzGEyRD9Og0QrE8= +github.com/aws/aws-sdk-go-v2/config v1.32.5/go.mod h1:xmDjzSUs/d0BB7ClzYPAZMmgQdrodNjPPhd6bGASwoE= +github.com/aws/aws-sdk-go-v2/credentials v1.19.5 h1:xMo63RlqP3ZZydpJDMBsH9uJ10hgHYfQFIk1cHDXrR4= +github.com/aws/aws-sdk-go-v2/credentials v1.19.5/go.mod h1:hhbH6oRcou+LpXfA/0vPElh/e0M3aFeOblE1sssAAEk= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.16 h1:80+uETIWS1BqjnN9uJ0dBUaETh+P1XwFy5vwHwK5r9k= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.16/go.mod h1:wOOsYuxYuB/7FlnVtzeBYRcjSRtQpAW0hCP7tIULMwo= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.16 h1:rgGwPzb82iBYSvHMHXc8h9mRoOUBZIGFgKb9qniaZZc= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.16/go.mod h1:L/UxsGeKpGoIj6DxfhOWHWQ/kGKcd4I1VncE4++IyKA= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.16 h1:1jtGzuV7c82xnqOVfx2F0xmJcOw5374L7N6juGW6x6U= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.16/go.mod h1:M2E5OQf+XLe+SZGmmpaI2yy+J326aFf6/+54PoxSANc= +github.com/aws/aws-sdk-go-v2/internal/ini v1.8.4 h1:WKuaxf++XKWlHWu9ECbMlha8WOEGm0OUEZqm4K/Gcfk= +github.com/aws/aws-sdk-go-v2/internal/ini v1.8.4/go.mod h1:ZWy7j6v1vWGmPReu0iSGvRiise4YI5SkR3OHKTZ6Wuc= +github.com/aws/aws-sdk-go-v2/internal/v4a v1.4.16 h1:CjMzUs78RDDv4ROu3JnJn/Ig1r6ZD7/T2DXLLRpejic= +github.com/aws/aws-sdk-go-v2/internal/v4a v1.4.16/go.mod h1:uVW4OLBqbJXSHJYA9svT9BluSvvwbzLQ2Crf6UPzR3c= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.4 h1:0ryTNEdJbzUCEWkVXEXoqlXV72J5keC1GvILMOuD00E= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.4/go.mod h1:HQ4qwNZh32C3CBeO6iJLQlgtMzqeG17ziAA/3KDJFow= +github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.9.7 h1:DIBqIrJ7hv+e4CmIk2z3pyKT+3B6qVMgRsawHiR3qso= +github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.9.7/go.mod h1:vLm00xmBke75UmpNvOcZQ/Q30ZFjbczeLFqGx5urmGo= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.16 h1:oHjJHeUy0ImIV0bsrX0X91GkV5nJAyv1l1CC9lnO0TI= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.16/go.mod h1:iRSNGgOYmiYwSCXxXaKb9HfOEj40+oTKn8pTxMlYkRM= +github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.19.16 h1:NSbvS17MlI2lurYgXnCOLvCFX38sBW4eiVER7+kkgsU= +github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.19.16/go.mod h1:SwT8Tmqd4sA6G1qaGdzWCJN99bUmPGHfRwwq3G5Qb+A= +github.com/aws/aws-sdk-go-v2/service/s3 v1.93.2 h1:U3ygWUhCpiSPYSHOrRhb3gOl9T5Y3kB8k5Vjs//57bE= +github.com/aws/aws-sdk-go-v2/service/s3 v1.93.2/go.mod h1:79S2BdqCJpScXZA2y+cpZuocWsjGjJINyXnOsf5DTz8= +github.com/aws/aws-sdk-go-v2/service/signin v1.0.4 h1:HpI7aMmJ+mm1wkSHIA2t5EaFFv5EFYXePW30p1EIrbQ= +github.com/aws/aws-sdk-go-v2/service/signin v1.0.4/go.mod h1:C5RdGMYGlfM0gYq/tifqgn4EbyX99V15P2V3R+VHbQU= +github.com/aws/aws-sdk-go-v2/service/sso v1.30.7 h1:eYnlt6QxnFINKzwxP5/Ucs1vkG7VT3Iezmvfgc2waUw= +github.com/aws/aws-sdk-go-v2/service/sso v1.30.7/go.mod h1:+fWt2UHSb4kS7Pu8y+BMBvJF0EWx+4H0hzNwtDNRTrg= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.12 h1:AHDr0DaHIAo8c9t1emrzAlVDFp+iMMKnPdYy6XO4MCE= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.12/go.mod h1:GQ73XawFFiWxyWXMHWfhiomvP3tXtdNar/fi8z18sx0= +github.com/aws/aws-sdk-go-v2/service/sts v1.41.5 h1:SciGFVNZ4mHdm7gpD1dgZYnCuVdX1s+lFTg4+4DOy70= +github.com/aws/aws-sdk-go-v2/service/sts v1.41.5/go.mod h1:iW40X4QBmUxdP+fZNOpfmkdMZqsovezbAeO+Ubiv2pk= +github.com/aws/smithy-go v1.24.0 h1:LpilSUItNPFr1eY85RYgTIg5eIEPtvFbskaFcmmIUnk= +github.com/aws/smithy-go v1.24.0/go.mod h1:LEj2LM3rBRQJxPZTB4KuzZkaZYnZPnvgIhb4pu07mx0= github.com/aymanbagabas/go-osc52/v2 v2.0.1 h1:HwpRHbFMcZLEVr42D4p7XBqjyuxQH5SMiErDT4WkJ2k= github.com/aymanbagabas/go-osc52/v2 v2.0.1/go.mod h1:uYgXzlJ7ZpABp8OJ+exZzJJhRNQ2ASbcXHWsFqH8hp8= github.com/cenkalti/backoff/v5 v5.0.3 h1:ZN+IMa753KfX5hd8vVaMixjnqRZ3y8CuJKRKj1xcsSM= diff --git a/pkg/commands/lab.go b/pkg/commands/lab.go index f43f16f..9679315 100644 --- a/pkg/commands/lab.go +++ b/pkg/commands/lab.go @@ -56,6 +56,7 @@ Use 'xcli lab [command] --help' for more information about a command.`, cmd.AddCommand(NewLabTUICommand(log, configPath)) cmd.AddCommand(NewLabDiagnoseCommand(log, configPath)) cmd.AddCommand(NewLabReleaseCommand(log, configPath)) + cmd.AddCommand(NewLabXatuCBTCommand(log, configPath)) return cmd } diff --git a/pkg/commands/lab_xatu_cbt.go b/pkg/commands/lab_xatu_cbt.go new file mode 100644 index 0000000..fed9099 --- /dev/null +++ b/pkg/commands/lab_xatu_cbt.go @@ -0,0 +1,31 @@ +package commands + +import ( + "github.com/sirupsen/logrus" + "github.com/spf13/cobra" +) + +// NewLabXatuCBTCommand creates the lab xatu-cbt command namespace. +func NewLabXatuCBTCommand(log logrus.FieldLogger, configPath string) *cobra.Command { + cmd := &cobra.Command{ + Use: "xatu-cbt", + Short: "Xatu-CBT related commands", + Long: `Commands for working with xatu-cbt, including generating seed data +for tests. + +Common workflows: + 1. Generate seed data for a single external model: + xcli lab xatu-cbt generate-seed-data + + 2. Generate test YAML for transformation models (auto-resolves dependencies): + xcli lab xatu-cbt generate-transformation-test + +Use 'xcli lab xatu-cbt [command] --help' for more information about a command.`, + } + + // Add xatu-cbt subcommands + cmd.AddCommand(NewLabXatuCBTGenerateSeedDataCommand(log, configPath)) + cmd.AddCommand(NewLabXatuCBTGenerateTransformationTestCommand(log, configPath)) + + return cmd +} diff --git a/pkg/commands/lab_xatu_cbt_generate_seed_data.go b/pkg/commands/lab_xatu_cbt_generate_seed_data.go new file mode 100644 index 0000000..bb1705c --- /dev/null +++ b/pkg/commands/lab_xatu_cbt_generate_seed_data.go @@ -0,0 +1,670 @@ +package commands + +import ( + "context" + "fmt" + "os" + "path/filepath" + + "github.com/ethpandaops/xcli/pkg/config" + "github.com/ethpandaops/xcli/pkg/constants" + "github.com/ethpandaops/xcli/pkg/seeddata" + "github.com/ethpandaops/xcli/pkg/ui" + "github.com/sirupsen/logrus" + "github.com/spf13/cobra" +) + +const ( + defaultRowLimit = 10000 +) + +// NewLabXatuCBTGenerateSeedDataCommand creates the lab xatu-cbt generate-seed-data command. +func NewLabXatuCBTGenerateSeedDataCommand(log logrus.FieldLogger, configPath string) *cobra.Command { + var ( + model string + network string + spec string + rangeColumn string + from string + to string + filters []string + limit int + output string + upload bool + noSanitizeIPs bool + ) + + cmd := &cobra.Command{ + Use: "generate-seed-data", + Short: "Generate seed data parquet files for xatu-cbt tests", + Long: `Generate seed data parquet files for xatu-cbt tests by extracting data +from the external ClickHouse cluster. + +This command requires hybrid mode to be enabled, as it needs access to +the external ClickHouse cluster containing production xatu data. + +Interactive mode (prompts for all required inputs): + xcli lab xatu-cbt generate-seed-data + +Scripted mode (all flags provided): + xcli lab xatu-cbt generate-seed-data \ + --model beacon_api_eth_v1_events_block \ + --network mainnet \ + --spec pectra \ + --range-column slot \ + --from 1000000 \ + --to 1001000 \ + --filter "status = VALID" \ + --filter "proposer_index > 100" \ + --upload + +The command outputs a test YAML template that can be used directly in xatu-cbt tests. + +S3 Upload Configuration (defaults to Cloudflare R2): + AWS_ACCESS_KEY_ID R2 API token Access Key ID + AWS_SECRET_ACCESS_KEY R2 API token Secret Access Key + S3_ENDPOINT Override endpoint (default: ethpandaops R2) + S3_BUCKET Override bucket (default: ethpandaops-platform-production-public)`, + RunE: func(cmd *cobra.Command, args []string) error { + return runGenerateSeedData(cmd.Context(), log, configPath, + model, network, spec, rangeColumn, from, to, filters, limit, output, upload, !noSanitizeIPs) + }, + } + + cmd.Flags().StringVar(&model, "model", "", "Table name from xatu-cbt external models") + cmd.Flags().StringVar(&network, "network", "", "Network name (mainnet, sepolia, etc.)") + cmd.Flags().StringVar(&spec, "spec", "", "Fork spec (pectra, fusaka, etc.)") + cmd.Flags().StringVar(&rangeColumn, "range-column", "", "Column to filter on (e.g., slot, epoch)") + cmd.Flags().StringVar(&from, "from", "", "Range start value") + cmd.Flags().StringVar(&to, "to", "", "Range end value") + cmd.Flags().StringArrayVar(&filters, "filter", nil, "Additional filter (format: 'column operator value', e.g., 'status = VALID')") + cmd.Flags().IntVar(&limit, "limit", defaultRowLimit, "Max rows (0 = unlimited)") + cmd.Flags().StringVarP(&output, "output", "o", "", "Output file path (default: ./{model}.parquet)") + cmd.Flags().BoolVar(&upload, "upload", false, "Upload to S3 after generation") + cmd.Flags().BoolVar(&noSanitizeIPs, "no-sanitize-ips", false, "Disable IP address sanitization (IPs are sanitized by default)") + + return cmd +} + +//nolint:funlen,cyclop,gocyclo,gocognit // Command handler with interactive flow +func runGenerateSeedData( + ctx context.Context, + log logrus.FieldLogger, + configPath string, + model, network, spec, rangeColumn, from, to string, + filterStrings []string, + limit int, + output string, + upload bool, + sanitizeIPs bool, +) error { + // Load configuration + labCfg, _, err := config.LoadLabConfig(configPath) + if err != nil { + return fmt.Errorf("failed to load config: %w", err) + } + + // Validate hybrid mode + if labCfg.Mode != constants.ModeHybrid { + return fmt.Errorf("this command requires hybrid mode (current mode: %s)\n"+ + "Run 'xcli lab mode hybrid' to switch to hybrid mode", labCfg.Mode) + } + + // Create generator + gen := seeddata.NewGenerator(log, labCfg) + + // Interactive mode: prompt for missing values + var promptErr error + + if model == "" { + model, promptErr = promptForModel(gen) + if promptErr != nil { + return promptErr + } + } else { + // Validate provided model + if validateErr := gen.ValidateModel(model); validateErr != nil { + return validateErr + } + } + + if network == "" { + network, promptErr = promptForNetwork(labCfg) + if promptErr != nil { + return promptErr + } + } + + if spec == "" { + spec, promptErr = promptForSpec() + if promptErr != nil { + return promptErr + } + } + + // Prompt for range (optional) + if rangeColumn == "" { + rangeColumn, from, to, promptErr = promptForRange() + if promptErr != nil { + return promptErr + } + } + + // Parse filter strings into Filter structs + filters, parseErr := parseFilters(filterStrings) + if parseErr != nil { + return parseErr + } + + // Prompt for additional filters (interactive mode) + if len(filterStrings) == 0 { + additionalFilters, filterErr := promptForFilters() + if filterErr != nil { + return filterErr + } + + filters = append(filters, additionalFilters...) + } + + // Prompt for limit if not specified via flag + if limit == defaultRowLimit { + limit, promptErr = promptForLimit() + if promptErr != nil { + return promptErr + } + } + + // Prompt for upload if not specified + if !upload { + upload, promptErr = promptForUpload() + if promptErr != nil { + return promptErr + } + } + + // Prompt for S3 filename if upload is enabled + var s3Filename string + + if upload { + s3Filename, promptErr = promptForS3Filename(model) + if promptErr != nil { + return promptErr + } + + s3Spinner := ui.NewSpinner("Checking S3 access") + + uploader, uploaderErr := seeddata.NewS3Uploader(ctx, log) + if uploaderErr != nil { + s3Spinner.Fail("Failed to initialize S3 client") + + return fmt.Errorf("failed to create S3 uploader: %w", uploaderErr) + } + + if accessErr := uploader.CheckAccess(ctx); accessErr != nil { + s3Spinner.Fail("S3 access check failed") + + return fmt.Errorf("S3 preflight check failed: %w", accessErr) + } + + s3Spinner.Success("S3 access verified") + + // Check if object already exists + exists, existsErr := uploader.ObjectExists(ctx, network, spec, s3Filename) + if existsErr != nil { + ui.Warning(fmt.Sprintf("Could not check if file exists: %v", existsErr)) + } else if exists { + existingURL := uploader.GetPublicURL(network, spec, s3Filename) + ui.Warning(fmt.Sprintf("File already exists: %s", existingURL)) + + overwrite, confirmErr := ui.Confirm("Overwrite existing file?") + if confirmErr != nil { + return confirmErr + } + + if !overwrite { + return fmt.Errorf("upload cancelled - file already exists") + } + } + } + + // Set default output path + if output == "" { + output = fmt.Sprintf("./%s.parquet", model) + } + + // Generate salt for IP sanitization if enabled + var salt string + + if sanitizeIPs { + var saltErr error + + salt, saltErr = seeddata.GenerateSalt() + if saltErr != nil { + return fmt.Errorf("failed to generate salt for IP sanitization: %w", saltErr) + } + } + + // Generate seed data + ui.Header("Generating seed data") + + spinner := ui.NewSpinner(fmt.Sprintf("Extracting data for %s", model)) + + result, err := gen.Generate(ctx, seeddata.GenerateOptions{ + Model: model, + Network: network, + Spec: spec, + RangeColumn: rangeColumn, + From: from, + To: to, + Filters: filters, + Limit: limit, + OutputPath: output, + SanitizeIPs: sanitizeIPs, + Salt: salt, + }) + if err != nil { + spinner.Fail("Failed to generate seed data") + + return fmt.Errorf("failed to generate seed data: %w", err) + } + + spinner.Success(fmt.Sprintf("Written to: %s (%s)", result.OutputPath, formatFileSize(result.FileSize))) + + // Display sanitized columns if any + if len(result.SanitizedColumns) > 0 { + ui.Info(fmt.Sprintf("Sanitized IP columns: %v", result.SanitizedColumns)) + } + + // Upload to S3 if requested + var publicURL string + + if upload { + publicURL, err = uploadToS3(ctx, log, result.OutputPath, network, spec, model, s3Filename) + if err != nil { + return err + } + + // Clean up local file after successful upload + if removeErr := os.Remove(result.OutputPath); removeErr != nil { + ui.Warning(fmt.Sprintf("Could not remove local file: %v", removeErr)) + } else { + ui.Info(fmt.Sprintf("Cleaned up local file: %s", result.OutputPath)) + } + } else { + // Use placeholder URL for template + publicURL = fmt.Sprintf("https://%s/%s/%s/%s/%s.parquet", + seeddata.DefaultS3PublicDomain, seeddata.DefaultS3Prefix, network, spec, model) + } + + // Generate test YAML + yamlFilename := s3Filename + if yamlFilename == "" { + yamlFilename = model + } + + yamlContent, err := seeddata.GenerateTestYAML(seeddata.TemplateData{ + Model: model, + Network: network, + Spec: spec, + URL: publicURL, + RowCount: estimateRowCount(result.FileSize), + }) + if err != nil { + return fmt.Errorf("failed to generate YAML template: %w", err) + } + + // Prompt to write YAML to xatu-cbt + writeYAML, writeErr := ui.Confirm("Write test YAML to xatu-cbt?") + if writeErr != nil { + return writeErr + } + + if writeYAML { + yamlPath := fmt.Sprintf("%s/tests/%s/%s/models/%s.yaml", + labCfg.Repos.XatuCBT, network, spec, yamlFilename) + + if yamlWriteErr := writeTestYAML(yamlPath, yamlContent); yamlWriteErr != nil { + return yamlWriteErr + } + } + + // Display test YAML + ui.Blank() + ui.Header("Test YAML") + + fmt.Println(yamlContent) + + if !upload { + ui.Blank() + ui.Warning("File was not uploaded. Update the URL in the YAML after uploading manually.") + } + + return nil +} + +func promptForModel(gen *seeddata.Generator) (string, error) { + models, err := gen.ListExternalModels() + if err != nil { + return "", fmt.Errorf("failed to list models: %w", err) + } + + options := make([]ui.SelectOption, 0, len(models)) + for _, m := range models { + options = append(options, ui.SelectOption{ + Label: m, + Value: m, + }) + } + + return ui.Select("Select a model", options) +} + +func promptForNetwork(labCfg *config.LabConfig) (string, error) { + options := make([]ui.SelectOption, 0, len(labCfg.Networks)) + + for _, net := range labCfg.Networks { + options = append(options, ui.SelectOption{ + Label: net.Name, + Value: net.Name, + }) + } + + return ui.Select("Select network", options) +} + +func promptForSpec() (string, error) { + options := []ui.SelectOption{ + {Label: "pectra", Value: "pectra"}, + {Label: "fusaka", Value: "fusaka"}, + } + + return ui.Select("Select spec", options) +} + +func promptForRange() (column, from, to string, err error) { + column, err = ui.TextInput("Filter by column (leave empty to skip)", "") + if err != nil { + return "", "", "", err + } + + if column == "" { + return "", "", "", nil + } + + from, err = ui.TextInput("From value", "") + if err != nil { + return "", "", "", err + } + + to, err = ui.TextInput("To value", "") + if err != nil { + return "", "", "", err + } + + return column, from, to, nil +} + +func promptForLimit() (int, error) { + limitStr, err := ui.TextInput(fmt.Sprintf("Row limit [%d]", defaultRowLimit), "") + if err != nil { + return 0, err + } + + if limitStr == "" { + return defaultRowLimit, nil + } + + var limit int + + _, err = fmt.Sscanf(limitStr, "%d", &limit) + if err != nil { + return 0, fmt.Errorf("invalid limit: %w", err) + } + + return limit, nil +} + +func promptForUpload() (bool, error) { + return ui.Confirm("Upload to S3?") +} + +func writeTestYAML(path, content string) error { + spinner := ui.NewSpinner(fmt.Sprintf("Writing YAML to %s", path)) + + // Ensure directory exists + dir := filepath.Dir(path) + if mkdirErr := os.MkdirAll(dir, 0755); mkdirErr != nil { + spinner.Fail("Failed to create directory") + + return fmt.Errorf("failed to create directory %s: %w", dir, mkdirErr) + } + + // Write the file (0644 is intentional - this is a config file to be committed to git) + if writeErr := os.WriteFile(path, []byte(content), 0644); writeErr != nil { //nolint:gosec // G306: config file needs to be readable + spinner.Fail("Failed to write YAML file") + + return fmt.Errorf("failed to write YAML to %s: %w", path, writeErr) + } + + spinner.Success(fmt.Sprintf("Written to: %s", path)) + + return nil +} + +func promptForS3Filename(defaultName string) (string, error) { + filename, err := ui.TextInput(fmt.Sprintf("S3 filename (without .parquet) [%s]", defaultName), "") + if err != nil { + return "", err + } + + if filename == "" { + return defaultName, nil + } + + return filename, nil +} + +func uploadToS3(ctx context.Context, log logrus.FieldLogger, localPath, network, spec, model, filename string) (string, error) { + spinner := ui.NewSpinner("Uploading to S3") + + uploader, err := seeddata.NewS3Uploader(ctx, log) + if err != nil { + spinner.Fail("Failed to create S3 uploader") + + return "", fmt.Errorf("failed to create S3 uploader: %w", err) + } + + result, err := uploader.Upload(ctx, seeddata.UploadOptions{ + LocalPath: localPath, + Network: network, + Spec: spec, + Model: model, + Filename: filename, + }) + if err != nil { + spinner.Fail("Failed to upload to S3") + + return "", fmt.Errorf("failed to upload to S3: %w", err) + } + + spinner.Success(fmt.Sprintf("Uploaded to: %s", result.PublicURL)) + + return result.PublicURL, nil +} + +func formatFileSize(bytes int64) string { + const unit = 1024 + + if bytes < unit { + return fmt.Sprintf("%d B", bytes) + } + + div, exp := int64(unit), 0 + for n := bytes / unit; n >= unit; n /= unit { + div *= unit + exp++ + } + + return fmt.Sprintf("%.1f %cB", float64(bytes)/float64(div), "KMGTPE"[exp]) +} + +func estimateRowCount(fileSize int64) int64 { + // Rough estimate: average 100 bytes per row in compressed parquet + // This is a placeholder - in reality we'd need to read the parquet metadata + return fileSize / 100 +} + +// parseFilters parses filter strings into Filter structs. +// Format: "column operator value" (e.g., "status = VALID"). +func parseFilters(filterStrings []string) ([]seeddata.Filter, error) { + if len(filterStrings) == 0 { + return nil, nil + } + + filters := make([]seeddata.Filter, 0, len(filterStrings)) + + for _, s := range filterStrings { + filter, err := parseFilterString(s) + if err != nil { + return nil, err + } + + filters = append(filters, filter) + } + + return filters, nil +} + +// parseFilterString parses a single filter string. +// Supports operators: =, !=, <>, >, <, >=, <=, LIKE, NOT LIKE, IN, NOT IN. +func parseFilterString(s string) (seeddata.Filter, error) { + // List of operators to check (longer ones first to avoid partial matches) + operators := []string{ + "NOT LIKE", "NOT IN", + ">=", "<=", "!=", "<>", + "LIKE", "IN", + "=", ">", "<", + } + + for _, op := range operators { + idx := findOperatorIndex(s, op) + if idx != -1 { + column := trimSpace(s[:idx]) + value := trimSpace(s[idx+len(op):]) + + if column == "" || value == "" { + return seeddata.Filter{}, fmt.Errorf("invalid filter format: %q (expected 'column %s value')", s, op) + } + + return seeddata.Filter{ + Column: column, + Operator: op, + Value: value, + }, nil + } + } + + return seeddata.Filter{}, fmt.Errorf("invalid filter format: %q (no valid operator found)", s) +} + +// findOperatorIndex finds the index of an operator in a string, case-insensitive. +func findOperatorIndex(s, op string) int { + upper := toUpper(s) + opUpper := toUpper(op) + + for i := 0; i <= len(upper)-len(opUpper); i++ { + if upper[i:i+len(opUpper)] == opUpper { + return i + } + } + + return -1 +} + +func trimSpace(s string) string { + start := 0 + end := len(s) + + for start < end && (s[start] == ' ' || s[start] == '\t') { + start++ + } + + for end > start && (s[end-1] == ' ' || s[end-1] == '\t') { + end-- + } + + return s[start:end] +} + +func toUpper(s string) string { + b := make([]byte, len(s)) + + for i := range len(s) { + c := s[i] + if c >= 'a' && c <= 'z' { + c -= 'a' - 'A' + } + + b[i] = c + } + + return string(b) +} + +// promptForFilters prompts the user to add additional filters interactively. +func promptForFilters() ([]seeddata.Filter, error) { + var filters []seeddata.Filter + + for { + addMore, err := ui.Confirm("Add a filter?") + if err != nil { + return nil, err + } + + if !addMore { + break + } + + column, err := ui.TextInput("Column name", "") + if err != nil { + return nil, err + } + + if column == "" { + continue + } + + operator, err := promptForOperator() + if err != nil { + return nil, err + } + + value, err := ui.TextInput("Value", "") + if err != nil { + return nil, err + } + + filters = append(filters, seeddata.Filter{ + Column: column, + Operator: operator, + Value: value, + }) + } + + return filters, nil +} + +func promptForOperator() (string, error) { + options := []ui.SelectOption{ + {Label: "= (equals)", Value: "="}, + {Label: "!= (not equals)", Value: "!="}, + {Label: "> (greater than)", Value: ">"}, + {Label: "< (less than)", Value: "<"}, + {Label: ">= (greater or equal)", Value: ">="}, + {Label: "<= (less or equal)", Value: "<="}, + {Label: "LIKE (pattern match)", Value: "LIKE"}, + {Label: "IN (in list)", Value: "IN"}, + } + + return ui.Select("Select operator", options) +} diff --git a/pkg/commands/lab_xatu_cbt_generate_transformation.go b/pkg/commands/lab_xatu_cbt_generate_transformation.go new file mode 100644 index 0000000..dccf32b --- /dev/null +++ b/pkg/commands/lab_xatu_cbt_generate_transformation.go @@ -0,0 +1,886 @@ +package commands + +import ( + "context" + "fmt" + "os" + "path/filepath" + "strings" + + "github.com/ethpandaops/xcli/pkg/config" + "github.com/ethpandaops/xcli/pkg/constants" + "github.com/ethpandaops/xcli/pkg/seeddata" + "github.com/ethpandaops/xcli/pkg/ui" + "github.com/sirupsen/logrus" + "github.com/spf13/cobra" +) + +// NewLabXatuCBTGenerateTransformationTestCommand creates the command. +func NewLabXatuCBTGenerateTransformationTestCommand(log logrus.FieldLogger, configPath string) *cobra.Command { + var ( + model string + network string + spec string + rangeColumn string + from string + to string + limit int + upload bool + aiAssertions bool + skipExisting bool + noSanitizeIPs bool + duration string + ) + + cmd := &cobra.Command{ + Use: "generate-transformation-test", + Short: "Generate test YAML for transformation models", + Long: `Generate complete test YAML files for xatu-cbt transformation models. + +This command: +1. Resolves the full dependency tree for a transformation model +2. Identifies all external model dependencies (leaf nodes) +3. Queries external ClickHouse for available data ranges +4. Finds the intersecting range across all dependencies +5. Generates seed data parquet files for all external models +6. Optionally uses Claude to generate meaningful assertions +7. Writes the complete test YAML to xatu-cbt + +This command requires hybrid mode to be enabled. + +Interactive mode: + xcli lab xatu-cbt generate-transformation-test + +Scripted mode: + xcli lab xatu-cbt generate-transformation-test \ + --model fct_data_column_availability_by_slot \ + --network sepolia \ + --spec fusaka \ + --range-column slot_start_date_time \ + --from "2025-10-27 00:26:00" \ + --to "2025-10-27 00:30:00" \ + --upload \ + --ai-assertions + +S3 Upload Configuration (defaults to Cloudflare R2): + AWS_ACCESS_KEY_ID R2 API token Access Key ID + AWS_SECRET_ACCESS_KEY R2 API token Secret Access Key + S3_ENDPOINT Override endpoint (default: ethpandaops R2) + S3_BUCKET Override bucket (default: ethpandaops-platform-production-public)`, + RunE: func(cmd *cobra.Command, args []string) error { + return runGenerateTransformationTest(cmd.Context(), log, configPath, + model, network, spec, rangeColumn, from, to, limit, upload, aiAssertions, skipExisting, !noSanitizeIPs, duration) + }, + } + + cmd.Flags().StringVar(&model, "model", "", "Transformation model name") + cmd.Flags().StringVar(&network, "network", "", "Network name (mainnet, sepolia, etc.)") + cmd.Flags().StringVar(&spec, "spec", "", "Fork spec (pectra, fusaka, etc.)") + cmd.Flags().StringVar(&rangeColumn, "range-column", "", "Override detected range column") + cmd.Flags().StringVar(&from, "from", "", "Range start value") + cmd.Flags().StringVar(&to, "to", "", "Range end value") + cmd.Flags().IntVar(&limit, "limit", defaultRowLimit, "Max rows per external model (0 = unlimited)") + cmd.Flags().BoolVar(&upload, "upload", false, "Upload parquets to S3 after generation") + cmd.Flags().BoolVar(&aiAssertions, "ai-assertions", false, "Use Claude to generate assertions") + cmd.Flags().BoolVar(&skipExisting, "skip-existing", false, "Skip generating seed data for existing S3 files") + cmd.Flags().BoolVar(&noSanitizeIPs, "no-sanitize-ips", false, "Disable IP address sanitization (IPs are sanitized by default)") + cmd.Flags().StringVar(&duration, "duration", "", "Time range duration (e.g., 1m, 5m, 10m, 30m)") + + return cmd +} + +//nolint:funlen,cyclop,gocyclo,gocognit // Command handler with interactive flow +func runGenerateTransformationTest( + ctx context.Context, + log logrus.FieldLogger, + configPath string, + model, network, spec, rangeColumn, from, to string, + limit int, + upload, aiAssertions, skipExisting, sanitizeIPs bool, + duration string, +) error { + // Load configuration + labCfg, _, err := config.LoadLabConfig(configPath) + if err != nil { + return fmt.Errorf("failed to load config: %w", err) + } + + // Validate hybrid mode + if labCfg.Mode != constants.ModeHybrid { + return fmt.Errorf("this command requires hybrid mode (current mode: %s)\n"+ + "Run 'xcli lab mode hybrid' to switch to hybrid mode", labCfg.Mode) + } + + // Create generator + gen := seeddata.NewGenerator(log, labCfg) + + // Interactive mode: prompt for model + var promptErr error + + if model == "" { + model, promptErr = promptForTransformationModel(labCfg.Repos.XatuCBT) + if promptErr != nil { + return promptErr + } + } + + // Resolve dependency tree + ui.Header("Resolving dependencies") + + depSpinner := ui.NewSpinner(fmt.Sprintf("Analyzing %s", model)) + + tree, err := seeddata.ResolveDependencyTree(model, labCfg.Repos.XatuCBT, nil) + if err != nil { + depSpinner.Fail("Failed to resolve dependencies") + + return fmt.Errorf("failed to resolve dependencies: %w", err) + } + + depSpinner.Success("Dependency tree resolved") + + // Display dependency tree + ui.Blank() + fmt.Println(tree.PrintTree(" ")) + + // Get external dependencies + externalModels := tree.GetExternalDependencies() + if len(externalModels) == 0 { + return fmt.Errorf("no external dependencies found for %s", model) + } + + ui.Info(fmt.Sprintf("External models needed (%d):", len(externalModels))) + + for _, m := range externalModels { + fmt.Printf(" • %s\n", m) + } + + ui.Blank() + + // Prompt for network + if network == "" { + network, promptErr = promptForNetwork(labCfg) + if promptErr != nil { + return promptErr + } + } + + // Prompt for spec + if spec == "" { + spec, promptErr = promptForSpec() + if promptErr != nil { + return promptErr + } + } + + // AI-assisted range discovery + ui.Blank() + ui.Header("Analyzing range strategies") + + // Prompt for duration if not specified + if duration == "" { + durationOpts := []ui.SelectOption{ + {Label: "5m", Description: "recommended", Value: "5m"}, + {Label: "30s", Description: "minimal test", Value: "30s"}, + {Label: "1m", Description: "quick test", Value: "1m"}, + {Label: "10m", Description: "", Value: "10m"}, + {Label: "30m", Description: "", Value: "30m"}, + {Label: "1h", Description: "large dataset", Value: "1h"}, + } + + selectedDuration, durationErr := ui.Select("Time range duration", durationOpts) + if durationErr != nil { + return durationErr + } + + duration = selectedDuration + } + + ui.Info(fmt.Sprintf("Using %s time range", duration)) + ui.Info("This may take a few minutes for models with many dependencies - grab a coffee ☕") + + var discoveryResult *seeddata.DiscoveryResult + + // Try AI discovery first + discoveryClient, discoveryErr := seeddata.NewClaudeDiscoveryClient(log, gen) + if discoveryErr != nil { + ui.Warning(fmt.Sprintf("Claude CLI not available: %v", discoveryErr)) + ui.Info("Falling back to heuristic range detection") + + // Fallback to heuristic detection + var rangeInfos map[string]*seeddata.RangeColumnInfo + + rangeInfos, err = seeddata.DetectRangeColumnsForModels(externalModels, labCfg.Repos.XatuCBT) + if err != nil { + return fmt.Errorf("failed to detect range columns: %w", err) + } + + discoveryResult, err = seeddata.FallbackRangeDiscovery(ctx, gen, externalModels, network, rangeInfos, duration, labCfg.Repos.XatuCBT) + if err != nil { + return fmt.Errorf("fallback range discovery failed: %w", err) + } + } else { + // Gather schema information + schemaSpinner := ui.NewSpinner("Gathering schema information") + + schemaInfo, schemaErr := discoveryClient.GatherSchemaInfo(ctx, externalModels, network, labCfg.Repos.XatuCBT) + if schemaErr != nil { + schemaSpinner.Fail("Failed to gather schema info") + + return fmt.Errorf("failed to gather schema info: %w", schemaErr) + } + + schemaSpinner.Success(fmt.Sprintf("Schema info gathered for %d models", len(schemaInfo))) + + // Display detected range info + for _, schema := range schemaInfo { + if schema.RangeInfo != nil { + status := "detected" + if !schema.RangeInfo.Detected { + status = "default" + } + + rangeStr := "" + if schema.RangeInfo.MinValue != "" && schema.RangeInfo.MaxValue != "" { + rangeStr = fmt.Sprintf(" [%s → %s]", schema.RangeInfo.MinValue, schema.RangeInfo.MaxValue) + } + + ui.Info(fmt.Sprintf(" • %s: %s (%s)%s", schema.Model, schema.RangeInfo.Column, status, rangeStr)) + } + } + + // Read transformation SQL + transformationSQL, sqlErr := seeddata.ReadTransformationSQL(model, labCfg.Repos.XatuCBT) + if sqlErr != nil { + return fmt.Errorf("failed to read transformation SQL: %w", sqlErr) + } + + // Read intermediate dependency SQL (for WHERE clause analysis) + intermediateSQL, intErr := seeddata.ReadIntermediateSQL(tree, labCfg.Repos.XatuCBT) + if intErr != nil { + ui.Warning(fmt.Sprintf("Could not read intermediate SQL: %v", intErr)) + // Continue without intermediate SQL - not critical + } + + // Convert to IntermediateSQL slice + var intermediateModels []seeddata.IntermediateSQL + for modelName, sql := range intermediateSQL { + intermediateModels = append(intermediateModels, seeddata.IntermediateSQL{ + Model: modelName, + SQL: sql, + }) + } + + if len(intermediateModels) > 0 { + ui.Info(fmt.Sprintf("Including %d intermediate model(s) for WHERE clause analysis", len(intermediateModels))) + } + + // Invoke Claude for analysis + ui.Blank() + + analysisSpinner := ui.NewSpinner("Analyzing correlation strategy with Claude") + + discoveryResult, err = discoveryClient.AnalyzeRanges(ctx, seeddata.DiscoveryInput{ + TransformationModel: model, + TransformationSQL: transformationSQL, + IntermediateModels: intermediateModels, + Network: network, + Duration: duration, + ExternalModels: schemaInfo, + }) + if err != nil { + analysisSpinner.Fail("AI analysis failed") + ui.Warning(fmt.Sprintf("Claude analysis failed: %v", err)) + ui.Info("Falling back to heuristic range detection") + + // Fallback to heuristic detection + rangeInfos, rangeErr := seeddata.DetectRangeColumnsForModels(externalModels, labCfg.Repos.XatuCBT) + if rangeErr != nil { + return fmt.Errorf("failed to detect range columns: %w", rangeErr) + } + + discoveryResult, err = seeddata.FallbackRangeDiscovery(ctx, gen, externalModels, network, rangeInfos, duration, labCfg.Repos.XatuCBT) + if err != nil { + return fmt.Errorf("fallback range discovery failed: %w", err) + } + } else { + analysisSpinner.Success(fmt.Sprintf("Strategy generated (confidence: %.0f%%)", discoveryResult.OverallConfidence*100)) + } + } + + // Validate that Claude's strategies cover all expected models + // This catches cases where Claude named a model differently + var missingModels []string + + for _, extModel := range externalModels { + if discoveryResult.GetStrategy(extModel) == nil { + missingModels = append(missingModels, extModel) + } + } + + if len(missingModels) > 0 { + ui.Blank() + ui.Warning("The following models are NOT covered by Claude's strategy:") + + for _, m := range missingModels { + ui.Warning(fmt.Sprintf(" • %s", m)) + } + + ui.Warning("These will use the primary range column, which may be incorrect.") + ui.Info("Claude's strategies cover these models:") + + for _, s := range discoveryResult.Strategies { + ui.Info(fmt.Sprintf(" • %s", s.Model)) + } + + ui.Blank() + + proceedMissing, missErr := ui.Confirm("Proceed anyway?") + if missErr != nil { + return missErr + } + + if !proceedMissing { + ui.Info("Aborted. Try regenerating with clearer model names.") + + return nil + } + } + + // Display the proposed strategy + ui.Blank() + ui.Header("Proposed Strategy") + ui.Info(fmt.Sprintf("Summary: %s", discoveryResult.Summary)) + ui.Blank() + ui.Info(fmt.Sprintf("Primary Range: %s (%s)", discoveryResult.PrimaryRangeColumn, discoveryResult.PrimaryRangeType)) + ui.Info(fmt.Sprintf(" From: %s", discoveryResult.FromValue)) + ui.Info(fmt.Sprintf(" To: %s", discoveryResult.ToValue)) + ui.Blank() + + ui.Info("Per-Table Strategies:") + + for _, strategy := range discoveryResult.Strategies { + confidence := fmt.Sprintf("%.0f%%", strategy.Confidence*100) + bridgeInfo := "" + + if strategy.RequiresBridge { + bridgeInfo = fmt.Sprintf(" (via %s)", strategy.BridgeTable) + } + + // Handle dimension tables (no range) vs regular tables + if strategy.RangeColumn == "" || strategy.ColumnType == seeddata.RangeColumnTypeNone { + ui.Info(fmt.Sprintf(" • %s: (dimension table - all data) %s%s", + strategy.Model, + confidence, + bridgeInfo, + )) + } else { + ui.Info(fmt.Sprintf(" • %s: %s [%s → %s] %s%s", + strategy.Model, + strategy.RangeColumn, + strategy.FromValue, + strategy.ToValue, + confidence, + bridgeInfo, + )) + } + + // Display additional filters if present + if strategy.FilterSQL != "" { + ui.Info(fmt.Sprintf(" Filter: %s", strategy.FilterSQL)) + } + + // Display correlation filter if present (for dimension tables) + if strategy.CorrelationFilter != "" { + // Truncate long subqueries for display + corrFilter := strategy.CorrelationFilter + if len(corrFilter) > 80 { + corrFilter = corrFilter[:77] + "..." + } + + ui.Info(fmt.Sprintf(" Correlation: %s", corrFilter)) + } + + // Display if optional + if strategy.Optional { + ui.Info(" (optional - LEFT JOIN)") + } + } + + // Display warnings + if len(discoveryResult.Warnings) > 0 { + ui.Blank() + + for _, warning := range discoveryResult.Warnings { + ui.Warning(warning) + } + } + + // Warn if low confidence + if discoveryResult.OverallConfidence < 0.5 { + ui.Blank() + ui.Warning("Low confidence score - manual review recommended") + } + + // Validate that each model has data in the proposed range + ui.Blank() + ui.Header("Validating data availability") + + validationSpinner := ui.NewSpinner("Checking row counts for each model") + + validation, validationErr := gen.ValidateStrategyHasData(ctx, discoveryResult, network) + if validationErr != nil { + validationSpinner.Fail("Validation failed") + + return fmt.Errorf("failed to validate strategy: %w", validationErr) + } + + validationSpinner.Success("Validation complete") + + // Display row counts + ui.Blank() + ui.Info("Data availability per model:") + + for _, count := range validation.Counts { + status := "✓" + if !count.HasData { + status = "✗" + } + + if count.Error != nil { + ui.Warning(fmt.Sprintf(" %s %s: error - %v", status, count.Model, count.Error)) + } else { + ui.Info(fmt.Sprintf(" %s %s: %d rows", status, count.Model, count.RowCount)) + } + } + + ui.Blank() + ui.Info(fmt.Sprintf("Total rows across all models: %d", validation.TotalRows)) + + if validation.MinRowModel != "" { + ui.Info(fmt.Sprintf("Model with fewest rows: %s (%d rows)", validation.MinRowModel, validation.MinRowCount)) + } + + // Handle errored models (timeouts, etc.) + if len(validation.ErroredModels) > 0 { + ui.Blank() + ui.Error("The following models FAILED to query (timeout or error):") + + for _, model := range validation.ErroredModels { + ui.Error(fmt.Sprintf(" • %s", model)) + } + + ui.Blank() + ui.Warning("These queries timed out - the tables may be too large or the range too wide.") + ui.Warning("Consider narrowing the block/time range, or proceed if you believe the data exists.") + + proceedWithErrors, errErr := ui.Confirm("Proceed anyway (assuming data exists)?") + if errErr != nil { + return errErr + } + + if !proceedWithErrors { + ui.Info("Aborted by user. Try a narrower range.") + + return nil + } + } + + // Handle empty models (zero rows) + if len(validation.EmptyModels) > 0 { + ui.Blank() + ui.Error("The following models have NO DATA in the proposed range:") + + for _, model := range validation.EmptyModels { + ui.Error(fmt.Sprintf(" • %s", model)) + } + + ui.Blank() + ui.Warning("Empty parquets will be generated for these models, which may cause test failures.") + + expandWindow, expandErr := ui.Confirm("Would you like to expand the time window and retry?") + if expandErr != nil { + return expandErr + } + + if expandWindow { + ui.Info("Please re-run the command with a larger time window or different range.") + ui.Info("Tip: Some tables (like canonical_execution_contracts) may have sparse data.") + + return nil + } + + // Let user proceed anyway if they want + proceedAnyway, proceedErr := ui.Confirm("Proceed anyway with potentially empty data?") + if proceedErr != nil { + return proceedErr + } + + if !proceedAnyway { + ui.Info("Aborted by user") + + return nil + } + } + + // User confirmation + ui.Blank() + + proceed, confirmErr := ui.Confirm("Proceed with this strategy?") + if confirmErr != nil { + return confirmErr + } + + if !proceed { + ui.Info("Aborted by user") + + return nil + } + + // Row limit handling: + // - With AI discovery: use unlimited (0) since Claude already picked sensible ranges + // - Manual/fallback: prompt for limit to avoid accidentally pulling too much data + // - Explicit --limit flag always respected + if discoveryResult != nil && limit == defaultRowLimit { + // AI discovery mode: no limit needed, Claude picked appropriate ranges + limit = 0 + + ui.Info("Using unlimited rows (AI discovery already optimized the range)") + } else if limit == defaultRowLimit { + // Fallback/manual mode: prompt for safety + limit, promptErr = promptForLimit() + if promptErr != nil { + return promptErr + } + } + + // Prompt for upload + if !upload { + upload, promptErr = ui.Confirm("Upload to S3?") + if promptErr != nil { + return promptErr + } + } + + // S3 preflight check if uploading + var uploader *seeddata.S3Uploader + + if upload { + s3Spinner := ui.NewSpinner("Checking S3 access") + + uploader, err = seeddata.NewS3Uploader(ctx, log) + if err != nil { + s3Spinner.Fail("Failed to initialize S3 client") + + return fmt.Errorf("failed to create S3 uploader: %w", err) + } + + if accessErr := uploader.CheckAccess(ctx); accessErr != nil { + s3Spinner.Fail("S3 access check failed") + + return fmt.Errorf("S3 preflight check failed: %w", accessErr) + } + + s3Spinner.Success("S3 access verified") + } + + // Generate salt for IP sanitization if enabled + var salt string + + if sanitizeIPs { + var saltErr error + + salt, saltErr = seeddata.GenerateSalt() + if saltErr != nil { + return fmt.Errorf("failed to generate salt for IP sanitization: %w", saltErr) + } + } + + // Generate seed data for all external models + ui.Blank() + ui.Header("Generating seed data") + + urls := make(map[string]string, len(externalModels)) + + for _, extModel := range externalModels { + // Get the strategy for this model + strategy := discoveryResult.GetStrategy(extModel) + if strategy == nil { + ui.Warning(fmt.Sprintf("No strategy found for %s, using defaults", extModel)) + + // Detect the correct range column for this model instead of blindly using primary + detectedCol, detectErr := seeddata.DetectRangeColumnForModel(extModel, labCfg.Repos.XatuCBT) + if detectErr != nil { + ui.Warning(fmt.Sprintf("Could not detect range column for %s: %v", extModel, detectErr)) + + detectedCol = seeddata.DefaultRangeColumn + } + + // Check if detected column type matches primary range type + colLower := strings.ToLower(detectedCol) + isTimeColumn := strings.Contains(colLower, "date") || strings.Contains(colLower, "time") + primaryIsTime := discoveryResult.PrimaryRangeType == seeddata.RangeColumnTypeTime + + if isTimeColumn != primaryIsTime { + // Column types don't match - we can't use primary range values + ui.Error(fmt.Sprintf(" %s uses %s but primary range is %s - cannot convert automatically", + extModel, detectedCol, discoveryResult.PrimaryRangeColumn)) + ui.Error(" Please re-run with Claude to get proper correlation, or use --range-column to override") + + return fmt.Errorf("cannot generate %s: range column type mismatch", extModel) + } + + strategy = &seeddata.TableRangeStrategy{ + Model: extModel, + RangeColumn: detectedCol, + FromValue: discoveryResult.FromValue, + ToValue: discoveryResult.ToValue, + } + } + + filename := seeddata.GetParquetFilename(model, extModel) + outputPath := fmt.Sprintf("./%s", filename) + + // Check if we should skip existing + if upload && skipExisting && uploader != nil { + exists, existsErr := uploader.ObjectExists(ctx, network, spec, filename[:len(filename)-8]) // Remove .parquet + if existsErr == nil && exists { + ui.Info(fmt.Sprintf(" ⏭ Skipping %s (already exists)", extModel)) + + urls[extModel] = uploader.GetPublicURL(network, spec, filename[:len(filename)-8]) + + continue + } + } + + // Show query parameters (helps debug empty parquets) + filterInfo := "" + if strategy.FilterSQL != "" { + filterInfo = fmt.Sprintf(" + filter: %s", strategy.FilterSQL) + } + + if strategy.CorrelationFilter != "" { + // Truncate long subqueries for display + corrFilter := strategy.CorrelationFilter + if len(corrFilter) > 60 { + corrFilter = corrFilter[:57] + "..." + } + + filterInfo += fmt.Sprintf(" + correlation: %s", corrFilter) + } + + // Handle dimension tables (no range) vs regular tables + if strategy.RangeColumn == "" || strategy.ColumnType == seeddata.RangeColumnTypeNone { + if strategy.CorrelationFilter != "" { + ui.Info(fmt.Sprintf(" %s: (correlated dimension table)%s", extModel, filterInfo)) + } else { + ui.Info(fmt.Sprintf(" %s: (dimension table - all data)%s", extModel, filterInfo)) + } + } else { + ui.Info(fmt.Sprintf(" %s: %s [%s → %s]%s", extModel, strategy.RangeColumn, strategy.FromValue, strategy.ToValue, filterInfo)) + } + + genSpinner := ui.NewSpinner(fmt.Sprintf("Generating %s", extModel)) + + result, genErr := gen.Generate(ctx, seeddata.GenerateOptions{ + Model: extModel, + Network: network, + Spec: spec, + RangeColumn: strategy.RangeColumn, + From: strategy.FromValue, + To: strategy.ToValue, + FilterSQL: strategy.FilterSQL, + CorrelationFilter: strategy.CorrelationFilter, + Limit: limit, + OutputPath: outputPath, + SanitizeIPs: sanitizeIPs, + Salt: salt, + }) + if genErr != nil { + genSpinner.Fail(fmt.Sprintf("Failed to generate %s", extModel)) + + return fmt.Errorf("failed to generate seed data for %s: %w", extModel, genErr) + } + + genSpinner.Success(fmt.Sprintf("%s (%s)", extModel, formatFileSize(result.FileSize))) + + // Warn if file is too large for comfortable test imports + const largeFileThreshold = 15 * 1024 * 1024 // 15MB + if result.FileSize > largeFileThreshold { + ui.Warning(fmt.Sprintf(" Large file (%s) - may slow down tests on low-powered machines. Consider using a shorter duration.", + formatFileSize(result.FileSize))) + } + + // Show query for first model to help debug empty parquets + if extModel == externalModels[0] { + ui.Info(fmt.Sprintf(" Query: %s", result.Query)) + } + + // Display sanitized columns if any + if len(result.SanitizedColumns) > 0 { + ui.Info(fmt.Sprintf(" Sanitized IP columns: %v", result.SanitizedColumns)) + } + + // Upload if requested + if upload && uploader != nil { + uploadSpinner := ui.NewSpinner(fmt.Sprintf("Uploading %s", extModel)) + + uploadResult, uploadErr := uploader.Upload(ctx, seeddata.UploadOptions{ + LocalPath: outputPath, + Network: network, + Spec: spec, + Model: extModel, + Filename: filename[:len(filename)-8], // Remove .parquet extension + }) + if uploadErr != nil { + uploadSpinner.Fail(fmt.Sprintf("Failed to upload %s", extModel)) + + return fmt.Errorf("failed to upload %s: %w", extModel, uploadErr) + } + + uploadSpinner.Success(fmt.Sprintf("Uploaded %s", extModel)) + ui.Info(fmt.Sprintf(" → %s", uploadResult.PublicURL)) + + urls[extModel] = uploadResult.PublicURL + + // Clean up local file + if removeErr := os.Remove(outputPath); removeErr != nil { + ui.Warning(fmt.Sprintf("Could not remove local file: %v", removeErr)) + } + } else { + // Use placeholder URL + urls[extModel] = fmt.Sprintf("https://%s/%s/%s/%s/%s", + seeddata.DefaultS3PublicDomain, seeddata.DefaultS3Prefix, network, spec, filename) + } + } + + // Generate assertions + var assertions []seeddata.Assertion + + if aiAssertions { + assertions, err = generateAIAssertions(ctx, log, model, externalModels, labCfg.Repos.XatuCBT) + if err != nil { + ui.Warning(fmt.Sprintf("AI assertion generation failed: %v", err)) + ui.Info("Using default assertions instead") + + assertions = seeddata.GetDefaultAssertions(model) + } + } else { + // Prompt for AI assertions + useAI, confirmErr := ui.Confirm("Generate assertions with Claude?") + if confirmErr != nil { + return confirmErr + } + + if useAI { + assertions, err = generateAIAssertions(ctx, log, model, externalModels, labCfg.Repos.XatuCBT) + if err != nil { + ui.Warning(fmt.Sprintf("AI assertion generation failed: %v", err)) + + assertions = seeddata.GetDefaultAssertions(model) + } + } else { + assertions = seeddata.GetDefaultAssertions(model) + } + } + + // Generate test YAML + ui.Blank() + ui.Header("Generating test YAML") + + yamlContent, err := seeddata.GenerateTransformationTestYAML(seeddata.TransformationTemplateData{ + Model: model, + Network: network, + Spec: spec, + ExternalModels: externalModels, + URLs: urls, + Assertions: assertions, + }) + if err != nil { + return fmt.Errorf("failed to generate YAML: %w", err) + } + + // Prompt to write YAML to xatu-cbt + writeYAML, writeErr := ui.Confirm("Write test YAML to xatu-cbt?") + if writeErr != nil { + return writeErr + } + + if writeYAML { + yamlPath := filepath.Join(labCfg.Repos.XatuCBT, "tests", network, spec, "models", model+".yaml") + + if yamlWriteErr := writeTestYAML(yamlPath, yamlContent); yamlWriteErr != nil { + return yamlWriteErr + } + } + + // Display test command + ui.Blank() + ui.Header("Test Command") + + testCmd := fmt.Sprintf("./bin/xatu-cbt test models %s --spec %s --network %s --verbose --force-rebuild", + model, spec, network) + fmt.Println(testCmd) + + // Display test YAML + ui.Blank() + ui.Header("Test YAML") + + fmt.Println(yamlContent) + + if !upload { + ui.Blank() + ui.Warning("Files were not uploaded. Update the URLs in the YAML after uploading manually.") + } + + return nil +} + +func promptForTransformationModel(xatuCBTPath string) (string, error) { + models, err := seeddata.ListTransformationModels(xatuCBTPath) + if err != nil { + return "", fmt.Errorf("failed to list transformation models: %w", err) + } + + options := make([]ui.SelectOption, 0, len(models)) + + for _, m := range models { + options = append(options, ui.SelectOption{ + Label: m, + Value: m, + }) + } + + return ui.Select("Select transformation model", options) +} + +func generateAIAssertions(ctx context.Context, log logrus.FieldLogger, model string, externalModels []string, xatuCBTPath string) ([]seeddata.Assertion, error) { + aiSpinner := ui.NewSpinner("Analyzing transformation SQL with Claude") + + client, err := seeddata.NewClaudeAssertionClient(log) + if err != nil { + aiSpinner.Fail("Claude CLI not available") + + return nil, fmt.Errorf("claude CLI not available: %w", err) + } + + // Read transformation SQL + sqlPath := filepath.Join(xatuCBTPath, "models", "transformations", model+".sql") + + sqlContent, err := os.ReadFile(sqlPath) + if err != nil { + aiSpinner.Fail("Failed to read transformation SQL") + + return nil, fmt.Errorf("failed to read SQL file: %w", err) + } + + assertions, err := client.GenerateAssertions(ctx, string(sqlContent), externalModels, model) + if err != nil { + aiSpinner.Fail("Claude assertion generation failed") + + return nil, err + } + + aiSpinner.Success(fmt.Sprintf("Generated %d assertions", len(assertions))) + + return assertions, nil +} diff --git a/pkg/seeddata/assertions.go b/pkg/seeddata/assertions.go new file mode 100644 index 0000000..f504775 --- /dev/null +++ b/pkg/seeddata/assertions.go @@ -0,0 +1,312 @@ +package seeddata + +import ( + "bytes" + "context" + "fmt" + "os" + "os/exec" + "path/filepath" + "regexp" + "strings" + "time" + + "github.com/sirupsen/logrus" + "gopkg.in/yaml.v3" +) + +// AssertionCheck represents a single assertion check within an assertion. +type AssertionCheck struct { + Type string `yaml:"type"` // greater_than, less_than, equals, etc. + Column string `yaml:"column"` // Column name to check + Value any `yaml:"value"` // Value to compare against +} + +// Assertion represents a test assertion for a transformation model. +type Assertion struct { + Name string `yaml:"name"` + SQL string `yaml:"sql"` + Assertions []AssertionCheck `yaml:"assertions,omitempty"` // For dynamic assertions + Expected map[string]any `yaml:"expected,omitempty"` // For exact value assertions +} + +// ClaudeAssertionClient handles assertion generation using Claude CLI. +type ClaudeAssertionClient struct { + log logrus.FieldLogger + claudePath string + timeout time.Duration +} + +// NewClaudeAssertionClient creates a new Claude client for assertion generation. +func NewClaudeAssertionClient(log logrus.FieldLogger) (*ClaudeAssertionClient, error) { + claudePath, err := findClaudeBinaryPath() + if err != nil { + return nil, fmt.Errorf("claude CLI not found: %w", err) + } + + return &ClaudeAssertionClient{ + log: log.WithField("component", "claude-assertions"), + claudePath: claudePath, + timeout: 3 * time.Minute, // Assertion generation can take time + }, nil +} + +// IsAvailable checks if Claude Code CLI is installed and available. +func (c *ClaudeAssertionClient) IsAvailable() bool { + if c.claudePath == "" { + return false + } + + info, err := os.Stat(c.claudePath) + if err != nil { + return false + } + + return !info.IsDir() && info.Mode()&0111 != 0 +} + +// GenerateAssertions uses Claude to analyze transformation SQL and suggest assertions. +func (c *ClaudeAssertionClient) GenerateAssertions(ctx context.Context, transformationSQL string, externalModels []string, modelName string) ([]Assertion, error) { + if !c.IsAvailable() { + return nil, fmt.Errorf("claude CLI is not available") + } + + prompt := c.buildAssertionPrompt(transformationSQL, externalModels, modelName) + + ctx, cancel := context.WithTimeout(ctx, c.timeout) + defer cancel() + + //nolint:gosec // claudePath is validated in findClaudeBinaryPath + cmd := exec.CommandContext(ctx, c.claudePath, "--print") + cmd.Stdin = strings.NewReader(prompt) + + var stdout, stderr bytes.Buffer + + cmd.Stdout = &stdout + cmd.Stderr = &stderr + + c.log.WithFields(logrus.Fields{ + "timeout": c.timeout, + "model": modelName, + }).Debug("invoking Claude CLI for assertion generation") + + if err := cmd.Run(); err != nil { + if ctx.Err() == context.DeadlineExceeded { + return nil, fmt.Errorf("claude assertion generation timed out after %s", c.timeout) + } + + return nil, fmt.Errorf("claude CLI failed: %w (stderr: %s)", err, stderr.String()) + } + + response := stdout.String() + if response == "" { + return nil, fmt.Errorf("claude returned empty response") + } + + c.log.WithField("response_length", len(response)).Debug("received Claude response") + + return c.parseAssertionResponse(response, modelName) +} + +// buildAssertionPrompt creates the prompt for Claude to generate assertions. +func (c *ClaudeAssertionClient) buildAssertionPrompt(transformationSQL string, externalModels []string, modelName string) string { + var sb strings.Builder + + sb.WriteString("Generate test assertions for this ClickHouse transformation model.\n\n") + sb.WriteString("## Instructions\n") + sb.WriteString("You are analyzing a ClickHouse transformation SQL to generate test assertions. ") + sb.WriteString("Output ONLY valid YAML that can be directly parsed. No explanations or markdown code blocks.\n\n") + sb.WriteString("Generate assertions that verify:\n") + sb.WriteString("1. Row count is greater than zero\n") + sb.WriteString("2. Key columns have valid ranges based on the SQL logic\n") + sb.WriteString("3. Aggregations are mathematically correct\n") + sb.WriteString("4. No data quality issues (nulls where unexpected, negative values for counts, etc.)\n\n") + + sb.WriteString("## Output Format\n") + sb.WriteString("Output assertions as a YAML list. Each assertion must have:\n") + sb.WriteString("- name: descriptive name\n") + sb.WriteString("- sql: the query (use `") + sb.WriteString(modelName) + sb.WriteString(" FINAL` for the table name)\n") + sb.WriteString("- assertions: list of checks with type, column, value\n\n") + + sb.WriteString("Valid assertion types: greater_than, less_than, greater_than_or_equal, less_than_or_equal, equals\n\n") + + sb.WriteString("Example output format:\n") + sb.WriteString("- name: Row count should be greater than zero\n") + sb.WriteString(" sql: |\n") + sb.WriteString(" SELECT COUNT(*) AS count FROM ") + sb.WriteString(modelName) + sb.WriteString(" FINAL\n") + sb.WriteString(" assertions:\n") + sb.WriteString(" - type: greater_than\n") + sb.WriteString(" column: count\n") + sb.WriteString(" value: 0\n\n") + + sb.WriteString("## Transformation Model: ") + sb.WriteString(modelName) + sb.WriteString("\n\n") + + sb.WriteString("## External Dependencies\n") + + for _, model := range externalModels { + sb.WriteString("- ") + sb.WriteString(model) + sb.WriteString("\n") + } + + sb.WriteString("\n## Transformation SQL\n```sql\n") + sb.WriteString(transformationSQL) + sb.WriteString("\n```\n\n") + + sb.WriteString("Generate 5-10 meaningful assertions. Output ONLY the YAML list, no other text.\n") + + return sb.String() +} + +// parseAssertionResponse parses Claude's YAML response into assertions. +func (c *ClaudeAssertionClient) parseAssertionResponse(response, modelName string) ([]Assertion, error) { + // Try to extract YAML from the response + yamlContent := extractYAMLFromResponse(response) + if yamlContent == "" { + return nil, fmt.Errorf("no valid YAML found in Claude response") + } + + var assertions []Assertion + + if unmarshalErr := yaml.Unmarshal([]byte(yamlContent), &assertions); unmarshalErr != nil { + // Try wrapping in a list if it failed + if !strings.HasPrefix(strings.TrimSpace(yamlContent), "-") { + yamlContent = "- " + strings.ReplaceAll(yamlContent, "\n", "\n ") + + if retryErr := yaml.Unmarshal([]byte(yamlContent), &assertions); retryErr != nil { + return nil, fmt.Errorf("failed to parse assertions YAML: %w", retryErr) + } + } else { + return nil, fmt.Errorf("failed to parse assertions YAML: %w", unmarshalErr) + } + } + + // Validate and clean up assertions + validAssertions := make([]Assertion, 0, len(assertions)) + + for _, a := range assertions { + if a.Name == "" || a.SQL == "" { + continue + } + + // Ensure table name uses FINAL + if !strings.Contains(a.SQL, "FINAL") { + a.SQL = strings.ReplaceAll(a.SQL, modelName, modelName+" FINAL") + } + + validAssertions = append(validAssertions, a) + } + + if len(validAssertions) == 0 { + return nil, fmt.Errorf("no valid assertions parsed from Claude response") + } + + return validAssertions, nil +} + +// extractYAMLFromResponse extracts YAML content from Claude's response. +func extractYAMLFromResponse(response string) string { + // First, try to find YAML in code blocks + codeBlockRe := regexp.MustCompile("```(?:yaml|yml)?\\n([\\s\\S]*?)```") + matches := codeBlockRe.FindStringSubmatch(response) + + if len(matches) > 1 { + return strings.TrimSpace(matches[1]) + } + + // If no code block, look for content starting with a YAML list or discovery YAML + lines := strings.Split(response, "\n") + + var yamlLines []string + + inYAML := false + + for _, line := range lines { + trimmed := strings.TrimSpace(line) + + // Start YAML when we see assertion-style or discovery-style start markers + // Include common typos/variations Claude might output + if strings.HasPrefix(trimmed, "- name:") || + strings.HasPrefix(trimmed, "primaryRangeType:") || + strings.HasPrefix(trimmed, "primaryrangeType:") || + strings.HasPrefix(trimmed, "primary_range_type:") { + inYAML = true + } + + if inYAML { + // Stop if we hit non-YAML content (text that doesn't look like YAML) + if trimmed != "" && + !strings.HasPrefix(line, " ") && + !strings.HasPrefix(line, "-") && + !strings.HasPrefix(line, "\t") && + !strings.Contains(line, ":") { + break + } + + yamlLines = append(yamlLines, line) + } + } + + if len(yamlLines) > 0 { + return strings.Join(yamlLines, "\n") + } + + // Last resort: return trimmed response hoping it's valid YAML + return strings.TrimSpace(response) +} + +// findClaudeBinaryPath locates the claude CLI binary. +func findClaudeBinaryPath() (string, error) { + // First, try `which claude` + whichCmd := exec.Command("which", "claude") + + output, err := whichCmd.Output() + if err == nil { + path := strings.TrimSpace(string(output)) + if path != "" { + return path, nil + } + } + + // Get home directory for path construction + home, err := os.UserHomeDir() + if err != nil { + home = os.Getenv("HOME") + } + + // Search in common locations + searchPaths := []string{ + filepath.Join(home, ".volta", "bin", "claude"), + "/usr/local/bin/claude", + filepath.Join(home, "go", "bin", "claude"), + filepath.Join(home, ".local", "bin", "claude"), + "/opt/homebrew/bin/claude", + } + + for _, path := range searchPaths { + if info, err := os.Stat(path); err == nil && !info.IsDir() { + return path, nil + } + } + + return "", fmt.Errorf("claude binary not found in PATH or common locations") +} + +// GetDefaultAssertions returns basic default assertions when Claude is not available. +func GetDefaultAssertions(modelName string) []Assertion { + return []Assertion{ + { + Name: "Row count should be greater than zero", + SQL: fmt.Sprintf(`SELECT COUNT(*) AS count FROM %s FINAL`, modelName), + Assertions: []AssertionCheck{ + {Type: "greater_than", Column: "count", Value: 0}, + }, + }, + } +} diff --git a/pkg/seeddata/batch.go b/pkg/seeddata/batch.go new file mode 100644 index 0000000..7c4e318 --- /dev/null +++ b/pkg/seeddata/batch.go @@ -0,0 +1,89 @@ +package seeddata + +import ( + "context" + "fmt" + + "github.com/sirupsen/logrus" +) + +// BatchGenerateOptions contains options for batch seed data generation. +type BatchGenerateOptions struct { + TransformationModel string // The transformation model name (used for parquet naming) + ExternalModels []string // List of external model names to generate + Network string // Network name (e.g., "mainnet", "sepolia") + Spec string // Fork spec (e.g., "pectra", "fusaka") + RangeColumn string // Column to filter on (e.g., "slot_start_date_time") + From string // Range start value + To string // Range end value + Filters []Filter // Additional filters (applied to all models) + Limit int // Max rows per model (0 = unlimited) + OutputDir string // Output directory for parquet files + SanitizeIPs bool // Enable IP address sanitization + Salt string // Salt for IP sanitization (shared across all models) +} + +// BatchGenerateResult contains the result of batch seed data generation. +type BatchGenerateResult struct { + Results map[string]*GenerateResult // model name -> generate result +} + +// BatchGenerate generates seed data for all external models with consistent range. +// Parquet files are named as {transformation}_{external_model}.parquet. +func (g *Generator) BatchGenerate(ctx context.Context, opts BatchGenerateOptions) (*BatchGenerateResult, error) { + if len(opts.ExternalModels) == 0 { + return nil, fmt.Errorf("no external models specified") + } + + if opts.Network == "" { + return nil, fmt.Errorf("network is required") + } + + if opts.OutputDir == "" { + opts.OutputDir = "." + } + + result := &BatchGenerateResult{ + Results: make(map[string]*GenerateResult, len(opts.ExternalModels)), + } + + for _, model := range opts.ExternalModels { + // Build output filename: {transformation}_{external_model}.parquet + filename := fmt.Sprintf("%s_%s.parquet", opts.TransformationModel, model) + outputPath := fmt.Sprintf("%s/%s", opts.OutputDir, filename) + + g.log.WithFields(logrus.Fields{ + "transformation": opts.TransformationModel, + "external_model": model, + "output": outputPath, + }).Info("generating seed data for external model") + + genOpts := GenerateOptions{ + Model: model, + Network: opts.Network, + Spec: opts.Spec, + RangeColumn: opts.RangeColumn, + From: opts.From, + To: opts.To, + Filters: opts.Filters, + Limit: opts.Limit, + OutputPath: outputPath, + SanitizeIPs: opts.SanitizeIPs, + Salt: opts.Salt, + } + + genResult, err := g.Generate(ctx, genOpts) + if err != nil { + return nil, fmt.Errorf("failed to generate seed data for %s: %w", model, err) + } + + result.Results[model] = genResult + } + + return result, nil +} + +// GetParquetFilename returns the parquet filename for an external model within a transformation. +func GetParquetFilename(transformationModel, externalModel string) string { + return fmt.Sprintf("%s_%s.parquet", transformationModel, externalModel) +} diff --git a/pkg/seeddata/dependencies.go b/pkg/seeddata/dependencies.go new file mode 100644 index 0000000..d453539 --- /dev/null +++ b/pkg/seeddata/dependencies.go @@ -0,0 +1,461 @@ +// Package seeddata provides functionality to generate seed data parquet files +// for xatu-cbt tests by extracting data from external ClickHouse. +package seeddata + +import ( + "bufio" + "fmt" + "os" + "path/filepath" + "regexp" + "strings" + + "gopkg.in/yaml.v3" +) + +// DependencyType represents the type of a dependency. +type DependencyType string + +const ( + // DependencyTypeExternal represents an external model dependency. + DependencyTypeExternal DependencyType = "external" + // DependencyTypeTransformation represents a transformation model dependency. + DependencyTypeTransformation DependencyType = "transformation" +) + +// Dependency represents a single model dependency. +type Dependency struct { + Type DependencyType + Name string +} + +// DependencyTree represents a model and its dependencies. +type DependencyTree struct { + Model string + Type DependencyType + Dependencies []Dependency + Children map[string]*DependencyTree // For transformation deps (recursive) +} + +// dependencyPattern matches dependency strings like "{{transformation}}.model_name" or "{{external}}.model_name". +var dependencyPattern = regexp.MustCompile(`^\{\{(transformation|external)\}\}\.(.+)$`) + +// IntervalType represents the interval type from external model frontmatter. +type IntervalType string + +const ( + // IntervalTypeSlot is slot-based interval (time ranges via slot_start_date_time). + IntervalTypeSlot IntervalType = "slot" + // IntervalTypeBlock is block number based interval. + IntervalTypeBlock IntervalType = "block" + // IntervalTypeEntity is for dimension/reference tables with no time range. + IntervalTypeEntity IntervalType = "entity" +) + +// intervalConfig represents the interval configuration in model frontmatter. +type intervalConfig struct { + Type IntervalType `yaml:"type"` +} + +// sqlFrontmatter represents the YAML frontmatter in SQL files. +type sqlFrontmatter struct { + Table string `yaml:"table"` + Dependencies []string `yaml:"dependencies"` + Interval intervalConfig `yaml:"interval"` +} + +// ParseDependencies parses the dependencies from a SQL file's YAML frontmatter. +func ParseDependencies(sqlPath string) ([]Dependency, error) { + frontmatter, err := parseFrontmatter(sqlPath) + if err != nil { + return nil, err + } + + deps := make([]Dependency, 0, len(frontmatter.Dependencies)) + + for _, depStr := range frontmatter.Dependencies { + dep, err := parseDependencyString(depStr) + if err != nil { + return nil, fmt.Errorf("invalid dependency '%s': %w", depStr, err) + } + + deps = append(deps, dep) + } + + return deps, nil +} + +// ResolveDependencyTree recursively resolves all dependencies for a transformation model. +// It returns a tree structure with all dependencies, where external models are leaf nodes. +func ResolveDependencyTree(model string, xatuCBTPath string, visited map[string]bool) (*DependencyTree, error) { + // Initialize visited map if nil (first call) + if visited == nil { + visited = make(map[string]bool, 16) + } + + // Check for circular dependencies + if visited[model] { + return nil, fmt.Errorf("circular dependency detected: %s", model) + } + + visited[model] = true + + defer func() { visited[model] = false }() + + // Try to find as transformation model (supports .sql and .yml extensions) + transformationPath := findModelFile(xatuCBTPath, "transformations", model) + + if transformationPath != "" { + return resolveTransformationTree(model, transformationPath, xatuCBTPath, visited) + } + + // If not found as transformation, check if it's an external model + externalPath := findModelFile(xatuCBTPath, "external", model) + if externalPath != "" { + return &DependencyTree{ + Model: model, + Type: DependencyTypeExternal, + Dependencies: nil, + Children: nil, + }, nil + } + + return nil, fmt.Errorf("model '%s' not found in transformations or external models", model) +} + +// findModelFile looks for a model file with supported extensions (.sql, .yml, .yaml). +func findModelFile(xatuCBTPath, folder, model string) string { + extensions := []string{".sql", ".yml", ".yaml"} + + for _, ext := range extensions { + path := filepath.Join(xatuCBTPath, "models", folder, model+ext) + if _, err := os.Stat(path); err == nil { + return path + } + } + + return "" +} + +// resolveTransformationTree resolves a transformation model's dependency tree. +func resolveTransformationTree(model, sqlPath, xatuCBTPath string, visited map[string]bool) (*DependencyTree, error) { + deps, err := ParseDependencies(sqlPath) + if err != nil { + return nil, fmt.Errorf("failed to parse dependencies for %s: %w", model, err) + } + + tree := &DependencyTree{ + Model: model, + Type: DependencyTypeTransformation, + Dependencies: deps, + Children: make(map[string]*DependencyTree, len(deps)), + } + + // Recursively resolve transformation dependencies + for _, dep := range deps { + if dep.Type == DependencyTypeTransformation { + childTree, err := ResolveDependencyTree(dep.Name, xatuCBTPath, visited) + if err != nil { + return nil, fmt.Errorf("failed to resolve dependency %s: %w", dep.Name, err) + } + + tree.Children[dep.Name] = childTree + } else { + // External dependencies are leaf nodes + tree.Children[dep.Name] = &DependencyTree{ + Model: dep.Name, + Type: DependencyTypeExternal, + Dependencies: nil, + Children: nil, + } + } + } + + return tree, nil +} + +// GetExternalDependencies returns all external model names from the dependency tree (leaf nodes). +// The result is deduplicated. +func (t *DependencyTree) GetExternalDependencies() []string { + seen := make(map[string]bool, 8) + externals := make([]string, 0, 8) + + t.collectExternalDeps(seen, &externals) + + return externals +} + +// collectExternalDeps recursively collects external dependencies. +func (t *DependencyTree) collectExternalDeps(seen map[string]bool, result *[]string) { + if t.Type == DependencyTypeExternal { + if !seen[t.Model] { + seen[t.Model] = true + *result = append(*result, t.Model) + } + + return + } + + for _, child := range t.Children { + child.collectExternalDeps(seen, result) + } +} + +// PrintTree returns a string representation of the dependency tree. +func (t *DependencyTree) PrintTree(indent string) string { + var sb strings.Builder + + typeStr := "transformation" + if t.Type == DependencyTypeExternal { + typeStr = "external" + } + + sb.WriteString(fmt.Sprintf("%s%s ({{%s}})\n", indent, t.Model, typeStr)) + + childIndent := indent + " " + + for _, dep := range t.Dependencies { + if child, ok := t.Children[dep.Name]; ok { + sb.WriteString(child.PrintTree(childIndent)) + } + } + + return sb.String() +} + +// GetIntermediateDependencies returns all intermediate (transformation) model names from the +// dependency tree, excluding the root model. These are non-leaf nodes that transform external data. +// The result is deduplicated. +func (t *DependencyTree) GetIntermediateDependencies() []string { + seen := make(map[string]bool, 8) + intermediates := make([]string, 0, 8) + + t.collectIntermediateDeps(seen, &intermediates, true) + + return intermediates +} + +// collectIntermediateDeps recursively collects intermediate (transformation) dependencies. +func (t *DependencyTree) collectIntermediateDeps(seen map[string]bool, result *[]string, isRoot bool) { + // Skip external models (leaf nodes) + if t.Type == DependencyTypeExternal { + return + } + + // Add this model if it's not the root and not already seen + if !isRoot && !seen[t.Model] { + seen[t.Model] = true + *result = append(*result, t.Model) + } + + // Recurse into children + for _, child := range t.Children { + child.collectIntermediateDeps(seen, result, false) + } +} + +// ReadIntermediateSQL reads the SQL content for all intermediate dependencies. +// Returns a map of model name to SQL content. +// Note: YAML script models (.yml/.yaml) are skipped as they don't contain SQL to analyze. +func ReadIntermediateSQL(tree *DependencyTree, xatuCBTPath string) (map[string]string, error) { + intermediates := tree.GetIntermediateDependencies() + result := make(map[string]string, len(intermediates)) + + for _, model := range intermediates { + modelPath := findModelFile(xatuCBTPath, "transformations", model) + if modelPath == "" { + // Model file not found, skip + continue + } + + // Only read SQL files - YAML script models don't have SQL to analyze + if !strings.HasSuffix(modelPath, ".sql") { + continue + } + + content, err := os.ReadFile(modelPath) + if err != nil { + return nil, fmt.Errorf("failed to read SQL for %s: %w", model, err) + } + + result[model] = string(content) + } + + return result, nil +} + +// ListTransformationModels returns a list of available transformation models from the xatu-cbt repo. +func ListTransformationModels(xatuCBTPath string) ([]string, error) { + modelsDir := filepath.Join(xatuCBTPath, "models", "transformations") + + entries, err := os.ReadDir(modelsDir) + if err != nil { + return nil, fmt.Errorf("failed to read transformations directory: %w", err) + } + + models := make([]string, 0, len(entries)) + + for _, entry := range entries { + if entry.IsDir() { + continue + } + + name := entry.Name() + + // Support .sql, .yml, and .yaml extensions + for _, ext := range []string{".sql", ".yml", ".yaml"} { + if strings.HasSuffix(name, ext) { + models = append(models, strings.TrimSuffix(name, ext)) + + break + } + } + } + + return models, nil +} + +// parseFrontmatter extracts and parses the YAML frontmatter from a SQL file, +// or parses a pure YAML file (.yml/.yaml) directly. +func parseFrontmatter(modelPath string) (*sqlFrontmatter, error) { + // Check if this is a pure YAML file (not SQL with frontmatter) + if strings.HasSuffix(modelPath, ".yml") || strings.HasSuffix(modelPath, ".yaml") { + return parseYAMLFile(modelPath) + } + + // Parse SQL file with YAML frontmatter + return parseSQLFrontmatter(modelPath) +} + +// parseYAMLFile parses a pure YAML model file (.yml or .yaml). +func parseYAMLFile(yamlPath string) (*sqlFrontmatter, error) { + content, err := os.ReadFile(yamlPath) + if err != nil { + return nil, fmt.Errorf("failed to read file: %w", err) + } + + var fm sqlFrontmatter + + if err := yaml.Unmarshal(content, &fm); err != nil { + return nil, fmt.Errorf("failed to parse YAML file: %w", err) + } + + return &fm, nil +} + +// parseSQLFrontmatter extracts and parses the YAML frontmatter from a SQL file. +func parseSQLFrontmatter(sqlPath string) (*sqlFrontmatter, error) { + file, err := os.Open(sqlPath) + if err != nil { + return nil, fmt.Errorf("failed to open file: %w", err) + } + defer file.Close() + + scanner := bufio.NewScanner(file) + + var yamlContent strings.Builder + + inFrontmatter := false + foundStart := false + + for scanner.Scan() { + line := scanner.Text() + + if strings.TrimSpace(line) == "---" { + if !foundStart { + foundStart = true + inFrontmatter = true + + continue + } + // Found end of frontmatter + break + } + + if inFrontmatter { + yamlContent.WriteString(line) + yamlContent.WriteString("\n") + } + } + + if err := scanner.Err(); err != nil { + return nil, fmt.Errorf("error reading file: %w", err) + } + + if !foundStart { + return nil, fmt.Errorf("no YAML frontmatter found in file") + } + + var fm sqlFrontmatter + + if err := yaml.Unmarshal([]byte(yamlContent.String()), &fm); err != nil { + return nil, fmt.Errorf("failed to parse YAML frontmatter: %w", err) + } + + return &fm, nil +} + +// GetExternalModelIntervalType returns the interval type for an external model. +// Returns IntervalTypeEntity for dimension tables, or the actual type (slot, block, etc.). +func GetExternalModelIntervalType(model, xatuCBTPath string) (IntervalType, error) { + modelPath := findModelFile(xatuCBTPath, "external", model) + if modelPath == "" { + return "", fmt.Errorf("external model '%s' not found", model) + } + + fm, err := parseFrontmatter(modelPath) + if err != nil { + return "", fmt.Errorf("failed to parse frontmatter: %w", err) + } + + if fm.Interval.Type == "" { + // Default to slot if not specified + return IntervalTypeSlot, nil + } + + return fm.Interval.Type, nil +} + +// IsEntityModel checks if an external model is an entity/dimension table. +func IsEntityModel(model, xatuCBTPath string) bool { + intervalType, err := GetExternalModelIntervalType(model, xatuCBTPath) + if err != nil { + return false + } + + return intervalType == IntervalTypeEntity +} + +// GetExternalModelIntervalTypes returns interval types for multiple external models. +func GetExternalModelIntervalTypes(models []string, xatuCBTPath string) (map[string]IntervalType, error) { + result := make(map[string]IntervalType, len(models)) + + for _, model := range models { + intervalType, err := GetExternalModelIntervalType(model, xatuCBTPath) + if err != nil { + return nil, fmt.Errorf("failed to get interval type for %s: %w", model, err) + } + + result[model] = intervalType + } + + return result, nil +} + +// parseDependencyString parses a dependency string like "{{transformation}}.model_name". +func parseDependencyString(depStr string) (Dependency, error) { + matches := dependencyPattern.FindStringSubmatch(depStr) + if matches == nil { + return Dependency{}, fmt.Errorf("does not match expected pattern {{type}}.name") + } + + depType := DependencyType(matches[1]) + if depType != DependencyTypeExternal && depType != DependencyTypeTransformation { + return Dependency{}, fmt.Errorf("unknown dependency type: %s", matches[1]) + } + + return Dependency{ + Type: depType, + Name: matches[2], + }, nil +} diff --git a/pkg/seeddata/discovery.go b/pkg/seeddata/discovery.go new file mode 100644 index 0000000..3519964 --- /dev/null +++ b/pkg/seeddata/discovery.go @@ -0,0 +1,1257 @@ +package seeddata + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "os" + "os/exec" + "path/filepath" + "regexp" + "strings" + "time" + + "github.com/sirupsen/logrus" + "gopkg.in/yaml.v3" +) + +// RangeColumnType identifies the semantic type of a range column. +type RangeColumnType string + +const ( + // RangeColumnTypeTime represents DateTime columns like slot_start_date_time. + RangeColumnTypeTime RangeColumnType = "time" + // RangeColumnTypeBlock represents block number columns (UInt64/Int64). + RangeColumnTypeBlock RangeColumnType = "block" + // RangeColumnTypeSlot represents slot number columns (UInt64/Int64). + RangeColumnTypeSlot RangeColumnType = "slot" + // RangeColumnTypeEpoch represents epoch number columns. + RangeColumnTypeEpoch RangeColumnType = "epoch" + // RangeColumnTypeNone represents dimension/reference tables with no time-based range. + RangeColumnTypeNone RangeColumnType = "none" + // RangeColumnTypeUnknown represents an unclassified column type. + RangeColumnTypeUnknown RangeColumnType = "unknown" +) + +// TableRangeStrategy describes how to filter a single table for seed data extraction. +type TableRangeStrategy struct { + Model string `yaml:"model"` + RangeColumn string `yaml:"rangeColumn"` + ColumnType RangeColumnType `yaml:"columnType"` + FromValue string `yaml:"fromValue"` + ToValue string `yaml:"toValue"` + FilterSQL string `yaml:"filterSql,omitempty"` + CorrelationFilter string `yaml:"correlationFilter,omitempty"` // Subquery filter for dimension tables + Optional bool `yaml:"optional,omitempty"` // True if table is optional (LEFT JOIN) + RequiresBridge bool `yaml:"requiresBridge"` + BridgeTable string `yaml:"bridgeTable,omitempty"` + BridgeJoinSQL string `yaml:"bridgeJoinSql,omitempty"` + Confidence float64 `yaml:"confidence"` + Reasoning string `yaml:"reasoning"` +} + +// DiscoveryResult contains the complete AI-generated range strategy. +type DiscoveryResult struct { + PrimaryRangeType RangeColumnType `yaml:"primaryRangeType"` + PrimaryRangeColumn string `yaml:"primaryRangeColumn"` + FromValue string `yaml:"fromValue"` + ToValue string `yaml:"toValue"` + Strategies []TableRangeStrategy `yaml:"strategies"` + OverallConfidence float64 `yaml:"overallConfidence"` + Summary string `yaml:"summary"` + Warnings []string `yaml:"warnings,omitempty"` +} + +// GetStrategy returns the strategy for a specific model. +// Uses case-insensitive matching and trims whitespace to handle variations in Claude's output. +func (d *DiscoveryResult) GetStrategy(model string) *TableRangeStrategy { + modelLower := strings.ToLower(strings.TrimSpace(model)) + + for i := range d.Strategies { + strategyModel := strings.ToLower(strings.TrimSpace(d.Strategies[i].Model)) + if strategyModel == modelLower { + return &d.Strategies[i] + } + } + + return nil +} + +// TableSchemaInfo contains schema information for a table. +type TableSchemaInfo struct { + Model string `yaml:"model"` + IntervalType IntervalType `yaml:"intervalType,omitempty"` // From model frontmatter (slot, block, entity) + Columns []ColumnInfo `yaml:"columns"` + SampleData []map[string]any `yaml:"sampleData,omitempty"` + RangeInfo *DetectedRange `yaml:"rangeInfo,omitempty"` +} + +// DetectedRange contains detected range information. +type DetectedRange struct { + Column string `yaml:"column"` + ColumnType RangeColumnType `yaml:"type"` + Detected bool `yaml:"detected"` + MinValue string `yaml:"minValue,omitempty"` + MaxValue string `yaml:"maxValue,omitempty"` +} + +// DiscoveryInput contains all information gathered for Claude analysis. +type DiscoveryInput struct { + TransformationModel string `yaml:"transformationModel"` + TransformationSQL string `yaml:"transformationSql"` + IntermediateModels []IntermediateSQL `yaml:"intermediateModels,omitempty"` // SQL for intermediate deps + Network string `yaml:"network"` + Duration string `yaml:"duration"` // e.g., "5m", "10m", "1h" + ExternalModels []TableSchemaInfo `yaml:"externalModels"` +} + +// IntermediateSQL contains SQL for an intermediate transformation model. +type IntermediateSQL struct { + Model string `yaml:"model"` + SQL string `yaml:"sql"` +} + +// ClaudeDiscoveryClient handles AI-assisted range discovery. +type ClaudeDiscoveryClient struct { + log logrus.FieldLogger + claudePath string + timeout time.Duration + gen *Generator +} + +// NewClaudeDiscoveryClient creates a new discovery client. +func NewClaudeDiscoveryClient(log logrus.FieldLogger, gen *Generator) (*ClaudeDiscoveryClient, error) { + claudePath, err := findClaudeBinaryPath() + if err != nil { + return nil, fmt.Errorf("claude CLI not found: %w", err) + } + + return &ClaudeDiscoveryClient{ + log: log.WithField("component", "claude-discovery"), + claudePath: claudePath, + timeout: 5 * time.Minute, // Discovery can take longer than assertions + gen: gen, + }, nil +} + +// IsAvailable checks if Claude CLI is accessible. +func (c *ClaudeDiscoveryClient) IsAvailable() bool { + if c.claudePath == "" { + return false + } + + info, err := os.Stat(c.claudePath) + if err != nil { + return false + } + + return !info.IsDir() && info.Mode()&0111 != 0 +} + +// GatherSchemaInfo collects schema information for all external models. +// All tables are treated equally - Claude will analyze the schema to determine +// the best filtering strategy for each table. +func (c *ClaudeDiscoveryClient) GatherSchemaInfo( + ctx context.Context, + models []string, + network string, + xatuCBTPath string, +) ([]TableSchemaInfo, error) { + schemas := make([]TableSchemaInfo, 0, len(models)) + + for _, model := range models { + c.log.WithField("model", model).Debug("gathering schema info") + + // Get interval type from model frontmatter (informational context for Claude) + intervalType, err := GetExternalModelIntervalType(model, xatuCBTPath) + if err != nil { + c.log.WithError(err).WithField("model", model).Warn("failed to get interval type") + + intervalType = IntervalTypeSlot // Default to slot + } + + // Get column schema from ClickHouse + columns, err := c.gen.DescribeTable(ctx, model) + if err != nil { + return nil, fmt.Errorf("failed to describe table %s: %w", model, err) + } + + schemaInfo := TableSchemaInfo{ + Model: model, + IntervalType: intervalType, + Columns: columns, + } + + // Try to detect range column from SQL file + rangeCol, detectErr := DetectRangeColumnForModel(model, xatuCBTPath) + if detectErr != nil { + c.log.WithError(detectErr).WithField("model", model).Debug("failed to detect range column from SQL") + + // For tables without a detected range column, find any time column in schema + // This handles entity tables and other tables without explicit range definitions + rangeCol = findTimeColumnInSchema(columns) + if rangeCol == "" { + rangeCol = DefaultRangeColumn // Last resort fallback + } + } + + // Classify the range column type + colType := ClassifyRangeColumn(rangeCol, columns) + + // Query the range for this model + var minVal, maxVal string + + modelRange, rangeErr := c.gen.QueryModelRange(ctx, model, network, rangeCol) + if rangeErr != nil { + c.log.WithError(rangeErr).WithField("model", model).Warn("failed to query model range") + } else { + minVal = modelRange.MinRaw + maxVal = modelRange.MaxRaw + } + + // Get sample data (limited to 3 rows for prompt size) + sampleData, sampleErr := c.gen.QueryTableSample(ctx, model, network, 3) + if sampleErr != nil { + c.log.WithError(sampleErr).WithField("model", model).Warn("failed to query sample data") + // Continue without sample data - not critical + } + + schemaInfo.SampleData = sampleData + schemaInfo.RangeInfo = &DetectedRange{ + Column: rangeCol, + ColumnType: colType, + Detected: detectErr == nil, // True if detected from SQL, false if fallback + MinValue: minVal, + MaxValue: maxVal, + } + + schemas = append(schemas, schemaInfo) + } + + return schemas, nil +} + +// AnalyzeRanges invokes Claude to analyze range strategies. +func (c *ClaudeDiscoveryClient) AnalyzeRanges( + ctx context.Context, + input DiscoveryInput, +) (*DiscoveryResult, error) { + if !c.IsAvailable() { + return nil, fmt.Errorf("claude CLI is not available") + } + + prompt := c.buildDiscoveryPrompt(input) + + ctx, cancel := context.WithTimeout(ctx, c.timeout) + defer cancel() + + //nolint:gosec // claudePath is validated in findClaudeBinaryPath + cmd := exec.CommandContext(ctx, c.claudePath, "--print") + cmd.Stdin = strings.NewReader(prompt) + + var stdout, stderr bytes.Buffer + + cmd.Stdout = &stdout + cmd.Stderr = &stderr + + c.log.WithFields(logrus.Fields{ + "timeout": c.timeout, + "model": input.TransformationModel, + }).Debug("invoking Claude CLI for range discovery") + + if err := cmd.Run(); err != nil { + if ctx.Err() == context.DeadlineExceeded { + return nil, fmt.Errorf("claude discovery timed out after %s", c.timeout) + } + + return nil, fmt.Errorf("claude CLI failed: %w (stderr: %s)", err, stderr.String()) + } + + response := stdout.String() + if response == "" { + return nil, fmt.Errorf("claude returned empty response") + } + + c.log.WithField("response_length", len(response)).Debug("received Claude response") + + return c.parseDiscoveryResponse(response) +} + +// buildDiscoveryPrompt constructs the prompt for Claude. +func (c *ClaudeDiscoveryClient) buildDiscoveryPrompt(input DiscoveryInput) string { + var sb strings.Builder + + sb.WriteString("## Task\n") + sb.WriteString("Analyze the following ClickHouse tables and determine the best strategy for extracting correlated seed data across all tables for testing.\n\n") + + sb.WriteString("## Context\n") + sb.WriteString(fmt.Sprintf("- Transformation Model: %s\n", input.TransformationModel)) + sb.WriteString(fmt.Sprintf("- Network: %s\n", input.Network)) + sb.WriteString(fmt.Sprintf("- Requested Duration: %s\n", input.Duration)) + sb.WriteString("- Goal: Extract a consistent slice of data from all external models that can be used together for testing the transformation\n\n") + + sb.WriteString("## Problem\n") + sb.WriteString("These tables may use different range column types:\n") + sb.WriteString("- Time-based columns (slot_start_date_time - DateTime)\n") + sb.WriteString("- Numeric columns (block_number - UInt64, slot - UInt64)\n") + sb.WriteString("- **Dimension/reference tables** (no time range - static lookup data like validator entities)\n\n") + sb.WriteString("You need to determine how to correlate these ranges so we get matching data across all tables.\n\n") + + sb.WriteString("**CRITICAL**: The transformation and its intermediate dependencies may have WHERE clauses that filter data.\n") + sb.WriteString("If you extract seed data that doesn't match these filters, the transformation will produce ZERO output rows!\n") + sb.WriteString("You MUST analyze the SQL and include any necessary filters in `filterSql` for each external model.\n\n") + + sb.WriteString("**IMPORTANT**: ALL tables must be filtered to limit data volume. Look at each table's schema to find appropriate filter columns.\n\n") + + sb.WriteString("**DIMENSION/ENTITY TABLES**: For tables that are JOINed as lookups (like validator entities):\n") + sb.WriteString("1. Analyze the JOIN condition in the transformation SQL to find the join key (e.g., validator_index)\n") + sb.WriteString("2. Use `correlationFilter` to filter by values that exist in the primary data tables\n") + sb.WriteString("3. **IMPORTANT**: Use `GLOBAL IN` (not just `IN`) for subqueries - ClickHouse requires this for distributed tables\n") + sb.WriteString("4. Example: If attestations JOIN on validator_index, filter entities to only those validators\n") + sb.WriteString("5. Mark tables as `optional: true` if the transformation can produce output without them (LEFT JOINs)\n") + sb.WriteString("6. If correlation isn't possible, use a reasonable time-based filter on any available DateTime column\n\n") + + sb.WriteString("## External Models and Their Schemas\n\n") + + for _, schema := range input.ExternalModels { + sb.WriteString(fmt.Sprintf("### %s\n", schema.Model)) + + // Show interval type from frontmatter (informational context for Claude) + if schema.IntervalType != "" { + sb.WriteString(fmt.Sprintf("Interval Type: %s\n", schema.IntervalType)) + } + + if schema.RangeInfo != nil { + sb.WriteString(fmt.Sprintf("Detected Range Column: %s (type: %s)\n", + schema.RangeInfo.Column, schema.RangeInfo.ColumnType)) + + if schema.RangeInfo.MinValue != "" && schema.RangeInfo.MaxValue != "" { + sb.WriteString(fmt.Sprintf("Available Range: %s to %s\n", + schema.RangeInfo.MinValue, schema.RangeInfo.MaxValue)) + } + } + + sb.WriteString("\nColumns:\n") + + for _, col := range schema.Columns { + sb.WriteString(fmt.Sprintf("- %s: %s\n", col.Name, col.Type)) + } + + if len(schema.SampleData) > 0 { + sb.WriteString("\nSample Data (first rows):\n```yaml\n") + + sampleYAML, err := yaml.Marshal(schema.SampleData) + if err == nil { + sb.WriteString(string(sampleYAML)) + } + + sb.WriteString("```\n") + } + + sb.WriteString("\n") + } + + sb.WriteString("## Transformation SQL\n```sql\n") + sb.WriteString(input.TransformationSQL) + sb.WriteString("\n```\n\n") + + // Include intermediate dependency SQL if available + if len(input.IntermediateModels) > 0 { + sb.WriteString("## Intermediate Dependency SQL\n") + sb.WriteString("The transformation depends on these intermediate models. Their WHERE clauses affect which seed data is usable:\n\n") + + for _, intermediate := range input.IntermediateModels { + sb.WriteString(fmt.Sprintf("### %s\n```sql\n%s\n```\n\n", intermediate.Model, intermediate.SQL)) + } + } + + sb.WriteString("## Instructions\n") + sb.WriteString("1. Analyze which tables can share a common range column directly\n") + sb.WriteString("2. For tables with different range types, determine if correlation is possible via:\n") + sb.WriteString(" - Direct conversion (e.g., slot to slot_start_date_time via calculation)\n") + sb.WriteString(" - Bridge table (e.g., canonical_beacon_block has both slot and execution block info)\n") + sb.WriteString(" - Shared columns in the data itself\n") + sb.WriteString("3. Recommend a primary range specification (type + column + from/to values)\n") + sb.WriteString(fmt.Sprintf("4. Use a %s time range (as requested by the user)\n", input.Duration)) + sb.WriteString("5. For each table, specify exactly how to filter it\n") + sb.WriteString("6. **CRITICAL**: Analyze ALL WHERE clauses in the transformation and intermediate SQL.\n") + sb.WriteString(" For each external model, identify any column filters that must be applied to get usable data.\n") + sb.WriteString(" Include these as `filterSql` - a SQL fragment like \"aggregation_bits = ''\" or \"attesting_validator_index IS NOT NULL\"\n\n") + + sb.WriteString("## Output Format\n") + sb.WriteString("Output ONLY valid YAML matching this EXACT structure.\n\n") + sb.WriteString("**CRITICAL FORMATTING RULES**:\n") + sb.WriteString("1. ALL field names MUST use proper camelCase (e.g., `primaryRangeColumn`, `fromValue`, `filterSql`)\n") + sb.WriteString("2. All datetime values MUST be quoted: `fromValue: \"2025-01-01 00:00:00\"`\n") + sb.WriteString("3. Output ONLY the YAML - no markdown code blocks, no explanations\n\n") + sb.WriteString("```yaml\n") + sb.WriteString("primaryRangeType: time\n") + sb.WriteString("primaryRangeColumn: slot_start_date_time\n") + sb.WriteString("fromValue: \"2025-01-01 00:00:00\"\n") + sb.WriteString("toValue: \"2025-01-01 00:05:00\"\n") + sb.WriteString("strategies:\n") + sb.WriteString(" - model: beacon_api_eth_v1_events_attestation\n") + sb.WriteString(" rangeColumn: slot_start_date_time\n") + sb.WriteString(" columnType: time\n") + sb.WriteString(" fromValue: \"2025-01-01 00:00:00\"\n") + sb.WriteString(" toValue: \"2025-01-01 00:05:00\"\n") + sb.WriteString(" filterSql: \"aggregation_bits = '' AND attesting_validator_index IS NOT NULL\"\n") + sb.WriteString(" requiresBridge: false\n") + sb.WriteString(" confidence: 0.9\n") + sb.WriteString(" reasoning: \"Filters from intermediate SQL\"\n") + sb.WriteString(" - model: ethseer_validator_entity\n") + sb.WriteString(" rangeColumn: \"\"\n") + sb.WriteString(" columnType: none\n") + sb.WriteString(" fromValue: \"\"\n") + sb.WriteString(" toValue: \"\"\n") + sb.WriteString(" filterSql: \"\"\n") + sb.WriteString(" correlationFilter: \"validator_index GLOBAL IN (SELECT DISTINCT attesting_validator_index FROM default.beacon_api_eth_v1_events_attestation WHERE slot_start_date_time >= toDateTime('2025-01-01 00:00:00') AND slot_start_date_time <= toDateTime('2025-01-01 00:05:00') AND meta_network_name = 'mainnet')\"\n") + sb.WriteString(" optional: true\n") + sb.WriteString(" requiresBridge: false\n") + sb.WriteString(" confidence: 0.9\n") + sb.WriteString(" reasoning: \"Entity table filtered by correlation - only validators appearing in attestation data\"\n") + sb.WriteString("overallConfidence: 0.85\n") + sb.WriteString("summary: \"Time-based primary range with filters from dependencies\"\n") + sb.WriteString("warnings:\n") + sb.WriteString(" - \"Filters applied to ensure usable seed data\"\n") + sb.WriteString("```\n\n") + + sb.WriteString("IMPORTANT:\n") + sb.WriteString("- Use actual values from the available ranges shown above\n") + sb.WriteString("- Pick a recent time window (last hour or so) within the intersection of all available ranges\n") + sb.WriteString("- For block_number tables, estimate block numbers that correspond to the chosen time window\n") + sb.WriteString("- Include ALL external models in the strategies list\n") + sb.WriteString("- **ANALYZE ALL WHERE CLAUSES** in transformation and intermediate SQL - missing filters will cause empty test output!\n") + sb.WriteString("- Include `filterSql` for each model (empty string if no additional filters needed)\n\n") + + sb.WriteString("**YOUR RESPONSE MUST START WITH `primaryRangeType:` - no preamble, no explanations, no markdown, just the raw YAML.**\n") + + return sb.String() +} + +// parseDiscoveryResponse parses Claude's YAML response. +func (c *ClaudeDiscoveryClient) parseDiscoveryResponse(response string) (*DiscoveryResult, error) { + // Extract YAML from response + yamlContent := extractYAMLFromResponse(response) + if yamlContent == "" { + // Log the raw response for debugging + c.log.WithField("response_preview", truncateString(response, 500)).Error("no valid YAML found in Claude response") + + return nil, fmt.Errorf("no valid YAML found in Claude response") + } + + // Normalize field names (Claude may output snake_case instead of camelCase) + beforeNorm := yamlContent + yamlContent = normalizeDiscoveryYAMLFields(yamlContent) + + if beforeNorm != yamlContent { + c.log.Debug("YAML normalization applied field name corrections") + } + + c.log.WithField("yaml_preview", truncateString(yamlContent, 300)).Debug("extracted YAML content") + + var result DiscoveryResult + + if err := yaml.Unmarshal([]byte(yamlContent), &result); err != nil { + c.log.WithFields(logrus.Fields{ + "error": err, + "yaml_content": yamlContent, + }).Error("failed to parse discovery YAML") + + // Include YAML preview in error for UI visibility + yamlPreview := yamlContent + if len(yamlPreview) > 800 { + yamlPreview = yamlPreview[:800] + "..." + } + + return nil, fmt.Errorf("failed to parse discovery YAML: %w\n\nClaude's output:\n%s", err, yamlPreview) + } + + // Validate result + if err := c.validateDiscoveryResult(&result); err != nil { + c.log.WithFields(logrus.Fields{ + "error": err, + "yaml_content": yamlContent, // Full content for debugging + "parsed": result, + }).Warn("invalid discovery result - showing YAML for debugging") + + // Include YAML preview in error for UI visibility + yamlPreview := yamlContent + if len(yamlPreview) > 500 { + yamlPreview = yamlPreview[:500] + "..." + } + + return nil, fmt.Errorf("invalid discovery result: %w\n\nClaude's YAML output:\n%s", err, yamlPreview) + } + + return &result, nil +} + +// normalizeDiscoveryYAMLFields converts common field name variations to expected camelCase +// and fixes common YAML formatting issues in Claude's output. +func normalizeDiscoveryYAMLFields(yamlContent string) string { + // Map of various field name formats to expected camelCase + // Includes snake_case, PascalCase, typos, and other variations Claude might output + replacements := map[string]string{ + // snake_case variations + "primary_range_type:": "primaryRangeType:", + "primary_range_column:": "primaryRangeColumn:", + // Common typos (missing capital letters) + "primaryrangeType:": "primaryRangeType:", + "primaryrangeColumn:": "primaryRangeColumn:", + "primaryRangetype:": "primaryRangeType:", + "primaryRangecolumn:": "primaryRangeColumn:", + "from_value:": "fromValue:", + "to_value:": "toValue:", + "range_column:": "rangeColumn:", + "column_type:": "columnType:", + "filter_sql:": "filterSql:", + "correlation_filter:": "correlationFilter:", + "requires_bridge:": "requiresBridge:", + "bridge_table:": "bridgeTable:", + "bridge_join_sql:": "bridgeJoinSql:", + "overall_confidence:": "overallConfidence:", + // PascalCase variations + "PrimaryRangeType:": "primaryRangeType:", + "PrimaryRangeColumn:": "primaryRangeColumn:", + "FromValue:": "fromValue:", + "ToValue:": "toValue:", + "RangeColumn:": "rangeColumn:", + "ColumnType:": "columnType:", + "FilterSql:": "filterSql:", + "FilterSQL:": "filterSql:", + "CorrelationFilter:": "correlationFilter:", + "RequiresBridge:": "requiresBridge:", + "BridgeTable:": "bridgeTable:", + "BridgeJoinSql:": "bridgeJoinSql:", + "BridgeJoinSQL:": "bridgeJoinSql:", + "OverallConfidence:": "overallConfidence:", + // Common typos/variations + "filterSql:": "filterSql:", + "filter:": "filterSql:", // Claude might shorten this + } + + result := yamlContent + for variant, camel := range replacements { + result = strings.ReplaceAll(result, variant, camel) + } + + // Fix unquoted datetime values (e.g., "fromValue: 2025-01-01 00:00:00" -> "fromValue: \"2025-01-01 00:00:00\"") + result = fixUnquotedDatetimes(result) + + return result +} + +// fixUnquotedDatetimes finds unquoted datetime values and adds quotes. +// Matches patterns like "fromValue: 2025-01-01 00:00:00" where the datetime is not quoted. +func fixUnquotedDatetimes(yamlContent string) string { + // Pattern: key followed by colon, space, then YYYY-MM-DD HH:MM:SS (not already quoted) + // We need to be careful not to double-quote already quoted values + lines := strings.Split(yamlContent, "\n") + result := make([]string, 0, len(lines)) + + datetimePattern := regexp.MustCompile(`^(\s*\w+:\s*)(\d{4}-\d{2}-\d{2}\s+\d{2}:\d{2}:\d{2})(\s*#.*)?$`) + + for _, line := range lines { + matches := datetimePattern.FindStringSubmatch(line) + if matches != nil { + // Found unquoted datetime - add quotes + prefix := matches[1] + datetime := matches[2] + + suffix := "" + if len(matches) > 3 { + suffix = matches[3] + } + + line = prefix + "\"" + datetime + "\"" + suffix + } + + result = append(result, line) + } + + return strings.Join(result, "\n") +} + +// truncateString truncates a string to maxLen characters, adding "..." if truncated. +func truncateString(s string, maxLen int) string { + if len(s) <= maxLen { + return s + } + + return s[:maxLen] + "..." +} + +// validateDiscoveryResult checks if the AI result is valid and complete. +func (c *ClaudeDiscoveryClient) validateDiscoveryResult(result *DiscoveryResult) error { + if result.PrimaryRangeColumn == "" { + return fmt.Errorf("primary_range_column is required") + } + + if result.FromValue == "" || result.ToValue == "" { + return fmt.Errorf("from_value and to_value are required") + } + + if len(result.Strategies) == 0 { + return fmt.Errorf("at least one strategy is required") + } + + for i, s := range result.Strategies { + if s.Model == "" { + return fmt.Errorf("strategy %d: model is required", i) + } + + // Tables can be filtered by: + // 1. Range column (rangeColumn + fromValue/toValue) + // 2. Correlation filter (subquery) + // 3. None type (dimension table - accepts all or filtered by filterSQL) + hasRangeFilter := s.RangeColumn != "" && s.FromValue != "" && s.ToValue != "" + hasCorrelationFilter := s.CorrelationFilter != "" + isNoneType := s.ColumnType == RangeColumnTypeNone + + // Must have at least one filtering mechanism + if !hasRangeFilter && !hasCorrelationFilter && !isNoneType { + return fmt.Errorf("strategy %d (%s): requires range filter (rangeColumn+from/to), correlationFilter, or columnType: none", i, s.Model) + } + + // If range column is specified, from/to are required + if s.RangeColumn != "" && !hasCorrelationFilter && (s.FromValue == "" || s.ToValue == "") { + return fmt.Errorf("strategy %d (%s): from_value and to_value are required when range_column is set without correlationFilter", i, s.Model) + } + } + + return nil +} + +// findTimeColumnInSchema looks for a time-based column in the schema. +// Prefers columns with "date_time" in the name, falls back to any DateTime column. +func findTimeColumnInSchema(columns []ColumnInfo) string { + // First, look for columns with "date_time" in the name (most common pattern) + for _, col := range columns { + colLower := strings.ToLower(col.Name) + if strings.Contains(colLower, "date_time") { + return col.Name + } + } + + // Fall back to any DateTime column + for _, col := range columns { + typeLower := strings.ToLower(col.Type) + if strings.Contains(typeLower, "datetime") { + return col.Name + } + } + + return "" +} + +// contains checks if a string slice contains a value. +func contains(slice []string, value string) bool { + for _, v := range slice { + if v == value { + return true + } + } + + return false +} + +// ClassifyRangeColumn determines the semantic type of a range column based on its name and schema type. +func ClassifyRangeColumn(column string, schema []ColumnInfo) RangeColumnType { + colLower := strings.ToLower(column) + + // Check by column name patterns + switch { + case strings.Contains(colLower, "date_time") || strings.Contains(colLower, "datetime"): + return RangeColumnTypeTime + case strings.Contains(colLower, "timestamp"): + return RangeColumnTypeTime + case colLower == "block_number" || strings.HasSuffix(colLower, "_block_number"): + return RangeColumnTypeBlock + case colLower == "slot" || strings.HasSuffix(colLower, "_slot"): + return RangeColumnTypeSlot + case colLower == "epoch" || strings.HasSuffix(colLower, "_epoch"): + return RangeColumnTypeEpoch + } + + // Check by schema type if available + for _, col := range schema { + if col.Name == column { + typeLower := strings.ToLower(col.Type) + if strings.Contains(typeLower, "datetime") { + return RangeColumnTypeTime + } + + break + } + } + + return RangeColumnTypeUnknown +} + +// QueryTableSample retrieves sample rows from a table for analysis. +func (g *Generator) QueryTableSample( + ctx context.Context, + model string, + network string, + limit int, +) ([]map[string]any, error) { + query := fmt.Sprintf(` + SELECT * + FROM default.%s + WHERE meta_network_name = '%s' + ORDER BY rand() + LIMIT %d + FORMAT JSON + `, model, network, limit) + + g.log.WithFields(logrus.Fields{ + "model": model, + "network": network, + "limit": limit, + }).Debug("querying table sample") + + chURL, err := g.buildClickHouseHTTPURL() + if err != nil { + return nil, fmt.Errorf("failed to build ClickHouse URL: %w", err) + } + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, chURL, strings.NewReader(query)) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + + req.Header.Set("Content-Type", "text/plain") + + client := &http.Client{ + Timeout: 30 * time.Second, + } + + resp, err := client.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to execute request: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + + return nil, fmt.Errorf("ClickHouse returned status %d: %s", resp.StatusCode, string(body)) + } + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response: %w", err) + } + + var jsonResp struct { + Data []map[string]any `json:"data"` + } + + if err := json.Unmarshal(body, &jsonResp); err != nil { + return nil, fmt.Errorf("failed to parse JSON response: %w", err) + } + + return jsonResp.Data, nil +} + +// categorizeModelsByType groups models into time, block, entity, and unknown categories. +// Uses interval types from frontmatter if available, falls back to column-based detection. +func categorizeModelsByType( + models []string, + intervalTypes map[string]IntervalType, + rangeInfos map[string]*RangeColumnInfo, +) (timeModels, blockModels, entityModels, unknownModels []string) { + timeModels = make([]string, 0) + blockModels = make([]string, 0) + entityModels = make([]string, 0) + unknownModels = make([]string, 0) + + for _, model := range models { + // First check frontmatter interval type (most accurate) + if intervalTypes != nil { + if intervalType, ok := intervalTypes[model]; ok { + switch intervalType { + case IntervalTypeEntity: + // Entity models need special handling - they have time columns but indexed by entity + entityModels = append(entityModels, model) + + continue + case IntervalTypeBlock: + blockModels = append(blockModels, model) + + continue + case IntervalTypeSlot: + timeModels = append(timeModels, model) + + continue + } + } + } + + // Fallback: use range column detection + info, ok := rangeInfos[model] + if !ok { + unknownModels = append(unknownModels, model) + + continue + } + + colLower := strings.ToLower(info.RangeColumn) + + switch { + case strings.Contains(colLower, "date_time") || strings.Contains(colLower, "timestamp"): + timeModels = append(timeModels, model) + case strings.Contains(colLower, "block"): + blockModels = append(blockModels, model) + default: + unknownModels = append(unknownModels, model) + } + } + + return timeModels, blockModels, entityModels, unknownModels +} + +// FallbackRangeDiscovery provides heuristic-based range discovery when Claude is unavailable. +func FallbackRangeDiscovery( + ctx context.Context, + gen *Generator, + models []string, + network string, + rangeInfos map[string]*RangeColumnInfo, + duration string, + xatuCBTPath string, +) (*DiscoveryResult, error) { + // Get interval types from model frontmatter for accurate categorization + intervalTypes, err := GetExternalModelIntervalTypes(models, xatuCBTPath) + if err != nil { + // Log warning but continue with column-based detection as fallback + gen.log.WithError(err).Warn("failed to get interval types from frontmatter, using column-based detection") + + intervalTypes = nil + } + + // Group models by interval type + timeModels, blockModels, entityModels, unknownModels := categorizeModelsByType(models, intervalTypes, rangeInfos) + + // Query ranges for models + var latestMin, earliestMax time.Time + + strategies := make([]TableRangeStrategy, 0, len(models)) + + for _, model := range models { + info := rangeInfos[model] + + // Check model category + isEntity := contains(entityModels, model) + isUnknown := contains(unknownModels, model) + + var rangeCol string + + var colType RangeColumnType + + // For entity models, find a time column in the schema + if isEntity { + // Query schema to find a time column + columns, schemaErr := gen.DescribeTable(ctx, model) + if schemaErr != nil { + gen.log.WithError(schemaErr).WithField("model", model).Warn("failed to describe entity table") + + strategies = append(strategies, TableRangeStrategy{ + Model: model, + ColumnType: RangeColumnTypeTime, + Confidence: 0.3, + Reasoning: fmt.Sprintf("Entity model (schema query failed: %v)", schemaErr), + }) + + continue + } + + timeCol := findTimeColumnInSchema(columns) + if timeCol == "" { + gen.log.WithField("model", model).Warn("no time column found for entity model") + + strategies = append(strategies, TableRangeStrategy{ + Model: model, + ColumnType: RangeColumnTypeTime, + Confidence: 0.3, + Reasoning: "Entity model - no time column found in schema", + }) + + continue + } + + rangeCol = timeCol + colType = RangeColumnTypeTime + } else if isUnknown { + // Unknown models - try default range column + rangeCol = DefaultRangeColumn + colType = RangeColumnTypeTime + } else { + // Time or block models - use detected range column + rangeCol = DefaultRangeColumn + if info != nil { + rangeCol = info.RangeColumn + } + + colType = ClassifyRangeColumn(rangeCol, nil) + } + + modelRange, queryErr := gen.QueryModelRange(ctx, model, network, rangeCol) + if queryErr != nil { + gen.log.WithError(queryErr).WithField("model", model).Warn("range query failed") + + strategies = append(strategies, TableRangeStrategy{ + Model: model, + RangeColumn: rangeCol, + ColumnType: colType, + Confidence: 0.3, + Reasoning: fmt.Sprintf("Range query failed: %v", queryErr), + }) + + continue + } + + if latestMin.IsZero() || modelRange.Min.After(latestMin) { + latestMin = modelRange.Min + } + + if earliestMax.IsZero() || modelRange.Max.Before(earliestMax) { + earliestMax = modelRange.Max + } + + reasoning := "Heuristic-based detection (Claude unavailable)" + if isEntity { + reasoning = fmt.Sprintf("Entity model - using %s for time filtering", rangeCol) + } + + strategies = append(strategies, TableRangeStrategy{ + Model: model, + RangeColumn: rangeCol, + ColumnType: colType, + Confidence: 0.7, + Reasoning: reasoning, + }) + } + + // Handle case where no models have valid ranges + hasRanges := !latestMin.IsZero() && !earliestMax.IsZero() + + var fromValue, toValue string + + var primaryType RangeColumnType + + var primaryColumn string + + if hasRanges { + // Check for intersection + if latestMin.After(earliestMax) { + return nil, fmt.Errorf("no intersecting range found across all models") + } + + // Parse duration string (e.g., "5m", "10m", "1h") + rangeDuration, parseErr := time.ParseDuration(duration) + if parseErr != nil { + rangeDuration = 5 * time.Minute // Default to 5 minutes if parsing fails + } + + // Use the last N minutes/hours of available data + effectiveMax := earliestMax.Add(-1 * time.Minute) // Account for ingestion lag + effectiveMin := effectiveMax.Add(-rangeDuration) + + if effectiveMin.Before(latestMin) { + effectiveMin = latestMin + } + + fromValue = effectiveMin.Format("2006-01-02 15:04:05") + toValue = effectiveMax.Format("2006-01-02 15:04:05") + + // Determine primary range type based on majority + if len(timeModels)+len(entityModels) >= len(blockModels) { + primaryType = RangeColumnTypeTime + primaryColumn = DefaultRangeColumn + } else { + primaryType = RangeColumnTypeBlock + primaryColumn = "block_number" + } + + // Update strategies with range values (skip strategies without valid range column) + for i := range strategies { + if strategies[i].RangeColumn != "" { + strategies[i].FromValue = fromValue + strategies[i].ToValue = toValue + } + } + } else { + // No valid ranges found - this is an error condition + return nil, fmt.Errorf("no valid range columns found for any model") + } + + warnings := make([]string, 0) + if len(blockModels) > 0 && len(timeModels) > 0 { + warnings = append(warnings, + "Mixed range column types detected (time and block). "+ + "Block-based tables may not correlate correctly with time-based filtering.") + } + + return &DiscoveryResult{ + PrimaryRangeType: primaryType, + PrimaryRangeColumn: primaryColumn, + FromValue: fromValue, + ToValue: toValue, + Strategies: strategies, + OverallConfidence: 0.6, // Lower confidence for heuristic + Summary: "Heuristic-based range detection (Claude unavailable)", + Warnings: warnings, + }, nil +} + +// ReadTransformationSQL reads the SQL file for a transformation model. +func ReadTransformationSQL(model, xatuCBTPath string) (string, error) { + sqlPath := filepath.Join(xatuCBTPath, "models", "transformations", model+".sql") + + content, err := os.ReadFile(sqlPath) + if err != nil { + return "", fmt.Errorf("failed to read transformation SQL: %w", err) + } + + return string(content), nil +} + +// ModelDataCount holds the row count validation result for a model. +type ModelDataCount struct { + Model string + Strategy *TableRangeStrategy + RowCount int64 + HasData bool + Error error +} + +// ValidationResult contains the results of validating a discovery strategy. +type ValidationResult struct { + Counts []ModelDataCount + AllHaveData bool + EmptyModels []string // Models with zero rows + ErroredModels []string // Models that failed to query (timeout, etc.) + TotalRows int64 + MinRowCount int64 + MinRowModel string +} + +// ValidateStrategyHasData queries each model to verify data exists in the proposed ranges. +func (g *Generator) ValidateStrategyHasData( + ctx context.Context, + result *DiscoveryResult, + network string, +) (*ValidationResult, error) { + counts := make([]ModelDataCount, 0, len(result.Strategies)) + emptyModels := make([]string, 0) + erroredModels := make([]string, 0) + + var totalRows int64 + + minRowCount := int64(-1) + minRowModel := "" + + for _, strategy := range result.Strategies { + count, err := g.QueryRowCount(ctx, strategy.Model, network, strategy.RangeColumn, strategy.FromValue, strategy.ToValue, strategy.FilterSQL, strategy.CorrelationFilter) + + modelCount := ModelDataCount{ + Model: strategy.Model, + Strategy: &strategy, + Error: err, + } + + if err != nil { + modelCount.HasData = false + + erroredModels = append(erroredModels, strategy.Model) + } else { + modelCount.RowCount = count + modelCount.HasData = count > 0 + totalRows += count + + if !modelCount.HasData { + emptyModels = append(emptyModels, strategy.Model) + } + + if minRowCount < 0 || count < minRowCount { + minRowCount = count + minRowModel = strategy.Model + } + } + + counts = append(counts, modelCount) + } + + return &ValidationResult{ + Counts: counts, + AllHaveData: len(emptyModels) == 0 && len(erroredModels) == 0, + EmptyModels: emptyModels, + ErroredModels: erroredModels, + TotalRows: totalRows, + MinRowCount: minRowCount, + MinRowModel: minRowModel, + }, nil +} + +// QueryRowCount queries the number of rows in a model for a given range. +// For dimension tables (empty rangeColumn), it counts all rows for the network. +func (g *Generator) QueryRowCount( + ctx context.Context, + model string, + network string, + rangeColumn string, + fromValue string, + toValue string, + filterSQL string, + correlationFilter string, +) (int64, error) { + // Build additional filter clause + filterClause := "" + if filterSQL != "" { + filterClause = fmt.Sprintf("\n AND %s", filterSQL) + } + + if correlationFilter != "" { + filterClause += fmt.Sprintf("\n AND %s", correlationFilter) + } + + var query string + + // Handle dimension tables (no range column) + if rangeColumn == "" { + query = fmt.Sprintf(` + SELECT COUNT(*) as cnt + FROM default.%s + WHERE meta_network_name = '%s'%s + FORMAT JSON + `, model, network, filterClause) + } else { + // Determine if this is a numeric or time-based column + isNumeric := !strings.Contains(strings.ToLower(rangeColumn), "date") && + !strings.Contains(strings.ToLower(rangeColumn), "time") + + if isNumeric { + query = fmt.Sprintf(` + SELECT COUNT(*) as cnt + FROM default.%s + WHERE meta_network_name = '%s' + AND %s >= %s + AND %s <= %s%s + FORMAT JSON + `, model, network, rangeColumn, fromValue, rangeColumn, toValue, filterClause) + } else { + query = fmt.Sprintf(` + SELECT COUNT(*) as cnt + FROM default.%s + WHERE meta_network_name = '%s' + AND %s >= toDateTime('%s') + AND %s <= toDateTime('%s')%s + FORMAT JSON + `, model, network, rangeColumn, fromValue, rangeColumn, toValue, filterClause) + } + } + + g.log.WithFields(logrus.Fields{ + "model": model, + "network": network, + "from": fromValue, + "to": toValue, + }).Debug("querying row count") + + chURL, err := g.buildClickHouseHTTPURL() + if err != nil { + return 0, fmt.Errorf("failed to build ClickHouse URL: %w", err) + } + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, chURL, strings.NewReader(query)) + if err != nil { + return 0, fmt.Errorf("failed to create request: %w", err) + } + + req.Header.Set("Content-Type", "text/plain") + + client := &http.Client{ + Timeout: 2 * time.Minute, // Row count queries on large tables can take time + } + + resp, err := client.Do(req) + if err != nil { + return 0, fmt.Errorf("failed to execute request: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + + return 0, fmt.Errorf("ClickHouse returned status %d: %s", resp.StatusCode, string(body)) + } + + body, err := io.ReadAll(resp.Body) + if err != nil { + return 0, fmt.Errorf("failed to read response: %w", err) + } + + var jsonResp struct { + Data []map[string]any `json:"data"` + } + + if err := json.Unmarshal(body, &jsonResp); err != nil { + return 0, fmt.Errorf("failed to parse JSON response: %w", err) + } + + if len(jsonResp.Data) == 0 { + return 0, nil + } + + // Extract count from response + cntVal, ok := jsonResp.Data[0]["cnt"] + if !ok { + return 0, fmt.Errorf("cnt not found in response") + } + + // Handle both string and numeric types + switch v := cntVal.(type) { + case string: + var count int64 + + _, err := fmt.Sscanf(v, "%d", &count) + + return count, err + case float64: + return int64(v), nil + case int64: + return v, nil + default: + return 0, fmt.Errorf("unexpected count type: %T", cntVal) + } +} + +// ExpandWindowMultiplier defines how much to expand the window on each retry. +const ExpandWindowMultiplier = 2 + +// SuggestExpandedStrategy creates a new strategy with an expanded time window. +// This is used when the original strategy has models with no data. +func SuggestExpandedStrategy(original *DiscoveryResult, multiplier int) *DiscoveryResult { + expanded := &DiscoveryResult{ + PrimaryRangeType: original.PrimaryRangeType, + PrimaryRangeColumn: original.PrimaryRangeColumn, + OverallConfidence: original.OverallConfidence * 0.9, // Reduce confidence slightly + Summary: fmt.Sprintf("%s (window expanded %dx)", original.Summary, multiplier), + Warnings: append([]string{}, original.Warnings...), + Strategies: make([]TableRangeStrategy, len(original.Strategies)), + } + + // For time-based ranges, we can try to expand by parsing and adjusting + // For now, just copy and add a warning - the actual expansion would need + // to be done with knowledge of the original window size + copy(expanded.Strategies, original.Strategies) + expanded.Warnings = append(expanded.Warnings, + fmt.Sprintf("Window expanded %dx to find data - verify data quality", multiplier)) + + return expanded +} diff --git a/pkg/seeddata/generator.go b/pkg/seeddata/generator.go new file mode 100644 index 0000000..0b57742 --- /dev/null +++ b/pkg/seeddata/generator.go @@ -0,0 +1,466 @@ +// Package seeddata provides functionality to generate seed data parquet files +// for xatu-cbt tests by extracting data from external ClickHouse. +package seeddata + +import ( + "context" + "fmt" + "io" + "net/http" + "net/url" + "os" + "path/filepath" + "strings" + "time" + + "github.com/ethpandaops/xcli/pkg/config" + "github.com/sirupsen/logrus" +) + +const ( + schemeHTTPS = "https" +) + +// Generator handles seed data generation from external ClickHouse. +type Generator struct { + log logrus.FieldLogger + cfg *config.LabConfig +} + +// NewGenerator creates a new seed data generator. +func NewGenerator(log logrus.FieldLogger, cfg *config.LabConfig) *Generator { + return &Generator{ + log: log.WithField("component", "seeddata"), + cfg: cfg, + } +} + +// GenerateOptions contains options for generating seed data. +type GenerateOptions struct { + Model string // Table name (e.g., "beacon_api_eth_v1_events_block") + Network string // Network name (e.g., "mainnet", "sepolia") + Spec string // Fork spec (e.g., "pectra", "fusaka") + RangeColumn string // Column to filter on (e.g., "slot", "epoch") + From string // Range start value + To string // Range end value + Filters []Filter // Additional filters + FilterSQL string // Raw SQL fragment for additional WHERE conditions (from AI discovery) + CorrelationFilter string // Subquery filter for dimension tables (e.g., "validator_index IN (SELECT ...)") + Limit int // Max rows (0 = unlimited) + OutputPath string // Output file path + SanitizeIPs bool // Enable IP address sanitization + Salt string // Salt for IP sanitization (shared across batch for consistency) + + // sanitizedColumns is an internal field set by Generate() when SanitizeIPs is true. + // It contains the pre-computed column list with IP sanitization expressions. + sanitizedColumns string +} + +// Filter represents an additional WHERE clause filter. +type Filter struct { + Column string // Column name + Operator string // Operator (=, !=, >, <, >=, <=, LIKE, IN, etc.) + Value string // Value to compare against +} + +// GenerateResult contains the result of seed data generation. +type GenerateResult struct { + OutputPath string // Path to generated parquet file + RowCount int64 // Number of rows extracted (estimated from file size) + FileSize int64 // File size in bytes + SanitizedColumns []string // IP columns that were sanitized (for display to user) + Query string // SQL query used (for debugging) +} + +// Generate extracts data from external ClickHouse and writes to a parquet file. +func (g *Generator) Generate(ctx context.Context, opts GenerateOptions) (*GenerateResult, error) { + // Validate options + if opts.Model == "" { + return nil, fmt.Errorf("model is required") + } + + if opts.Network == "" { + return nil, fmt.Errorf("network is required") + } + + // Build output path if not specified + if opts.OutputPath == "" { + opts.OutputPath = fmt.Sprintf("./%s.parquet", opts.Model) + } + + // Ensure output directory exists + dir := filepath.Dir(opts.OutputPath) + if dir != "." && dir != "" { + if err := os.MkdirAll(dir, 0755); err != nil { + return nil, fmt.Errorf("failed to create output directory: %w", err) + } + } + + // Build sanitized column list if IP sanitization is enabled + var sanitizedColumns []string + + if opts.SanitizeIPs && opts.Salt != "" { + result, err := g.BuildSanitizedColumnList(ctx, opts.Model, opts.Salt) + if err != nil { + return nil, fmt.Errorf("failed to build sanitized column list: %w", err) + } + + opts.sanitizedColumns = result.ColumnExpr + sanitizedColumns = result.SanitizedColumns + } + + // Build the SQL query + query := g.buildQuery(opts) + + g.log.WithFields(logrus.Fields{ + "model": opts.Model, + "network": opts.Network, + "output": opts.OutputPath, + "range_column": opts.RangeColumn, + "from": opts.From, + "to": opts.To, + "query": query, + }).Info("generating seed data") + + // Execute query and stream to file + fileSize, err := g.executeQueryToFile(ctx, query, opts.OutputPath) + if err != nil { + return nil, fmt.Errorf("failed to execute query: %w", err) + } + + return &GenerateResult{ + OutputPath: opts.OutputPath, + FileSize: fileSize, + SanitizedColumns: sanitizedColumns, + Query: query, + }, nil +} + +// buildQuery constructs the SQL query for extracting seed data. +func (g *Generator) buildQuery(opts GenerateOptions) string { + var sb strings.Builder + + // Use sanitized column list if available, otherwise SELECT * + if opts.sanitizedColumns != "" { + sb.WriteString("SELECT ") + sb.WriteString(opts.sanitizedColumns) + sb.WriteString(" FROM default.") + } else { + sb.WriteString("SELECT * FROM default.") + } + + sb.WriteString(opts.Model) + sb.WriteString("\nWHERE meta_network_name = '") + sb.WriteString(opts.Network) + sb.WriteString("'") + + // Add range filter if specified + // Use column-name-based detection (same logic as validation query in discovery.go) + if opts.RangeColumn != "" && opts.From != "" && opts.To != "" { + colLower := strings.ToLower(opts.RangeColumn) + isTimeColumn := strings.Contains(colLower, "date") || strings.Contains(colLower, "time") + + sb.WriteString("\n AND ") + sb.WriteString(opts.RangeColumn) + sb.WriteString(" >= ") + + if isTimeColumn { + sb.WriteString(fmt.Sprintf("toDateTime('%s')", opts.From)) + } else { + sb.WriteString(opts.From) // Numeric value as-is + } + + sb.WriteString("\n AND ") + sb.WriteString(opts.RangeColumn) + sb.WriteString(" <= ") + + if isTimeColumn { + sb.WriteString(fmt.Sprintf("toDateTime('%s')", opts.To)) + } else { + sb.WriteString(opts.To) // Numeric value as-is + } + } + + // Add additional filters (structured) + for _, filter := range opts.Filters { + sb.WriteString("\n AND ") + sb.WriteString(filter.Column) + sb.WriteString(" ") + sb.WriteString(filter.Operator) + sb.WriteString(" ") + sb.WriteString(formatSQLValue(filter.Value)) + } + + // Add raw SQL filter if specified (from AI discovery) + if opts.FilterSQL != "" { + sb.WriteString("\n AND ") + sb.WriteString(opts.FilterSQL) + } + + // Add correlation filter if specified (subquery for dimension tables) + if opts.CorrelationFilter != "" { + sb.WriteString("\n AND ") + sb.WriteString(opts.CorrelationFilter) + } + + // Add limit if specified + if opts.Limit > 0 { + sb.WriteString(fmt.Sprintf("\nLIMIT %d", opts.Limit)) + } + + sb.WriteString("\nFORMAT Parquet") + + return sb.String() +} + +// formatSQLValue formats a value for use in SQL. +// Numeric values are returned as-is, datetime values use toDateTime(), other values are quoted. +func formatSQLValue(val string) string { + // Check if value is purely numeric (integer or decimal) + if isNumeric(val) { + return val + } + + // Check if value looks like a datetime (YYYY-MM-DD HH:MM:SS) + if isDateTime(val) { + return fmt.Sprintf("toDateTime('%s')", val) + } + + // Quote non-numeric values (strings, etc.) + // Escape single quotes by doubling them + escaped := strings.ReplaceAll(val, "'", "''") + + return "'" + escaped + "'" +} + +// isDateTime checks if a string looks like a datetime (YYYY-MM-DD HH:MM:SS). +func isDateTime(s string) bool { + // Must be exactly 19 characters: YYYY-MM-DD HH:MM:SS + if len(s) != 19 { + return false + } + + // Check format: YYYY-MM-DD HH:MM:SS + // Positions: 0123456789012345678 + // 2025-12-10 20:00:00 + if s[4] != '-' || s[7] != '-' || s[10] != ' ' || s[13] != ':' || s[16] != ':' { + return false + } + + // Check that other positions are digits + digitPositions := []int{0, 1, 2, 3, 5, 6, 8, 9, 11, 12, 14, 15, 17, 18} + for _, pos := range digitPositions { + if s[pos] < '0' || s[pos] > '9' { + return false + } + } + + return true +} + +// isNumeric checks if a string represents a numeric value. +func isNumeric(s string) bool { + if s == "" { + return false + } + + // Allow leading minus sign + start := 0 + if s[0] == '-' { + start = 1 + + if len(s) == 1 { + return false + } + } + + hasDecimal := false + + for i := start; i < len(s); i++ { + c := s[i] + if c == '.' { + if hasDecimal { + return false // Multiple decimals + } + + hasDecimal = true + + continue + } + + if c < '0' || c > '9' { + return false + } + } + + return true +} + +// executeQueryToFile executes a query and streams the result to a file. +func (g *Generator) executeQueryToFile(ctx context.Context, query, outputPath string) (int64, error) { + // Parse external ClickHouse URL + chURL, err := g.buildClickHouseHTTPURL() + if err != nil { + return 0, fmt.Errorf("failed to build ClickHouse URL: %w", err) + } + + g.log.WithField("query", query).Debug("executing query") + + // Create HTTP request + req, err := http.NewRequestWithContext(ctx, http.MethodPost, chURL, strings.NewReader(query)) + if err != nil { + return 0, fmt.Errorf("failed to create request: %w", err) + } + + req.Header.Set("Content-Type", "text/plain") + + // Create HTTP client with timeout + client := &http.Client{ + Timeout: 10 * time.Minute, // Allow long queries + } + + // Execute request + resp, err := client.Do(req) + if err != nil { + return 0, fmt.Errorf("failed to execute request: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + + return 0, fmt.Errorf("ClickHouse returned status %d: %s", resp.StatusCode, string(body)) + } + + // Create output file + outFile, err := os.Create(outputPath) + if err != nil { + return 0, fmt.Errorf("failed to create output file: %w", err) + } + defer outFile.Close() + + // Stream response to file + written, err := io.Copy(outFile, resp.Body) + if err != nil { + return 0, fmt.Errorf("failed to write output file: %w", err) + } + + return written, nil +} + +// buildClickHouseHTTPURL constructs the HTTP URL for ClickHouse queries. +func (g *Generator) buildClickHouseHTTPURL() (string, error) { + externalURL := g.cfg.Infrastructure.ClickHouse.Xatu.ExternalURL + + // Parse the configured URL + parsed, err := url.Parse(externalURL) + if err != nil { + return "", fmt.Errorf("failed to parse external URL: %w", err) + } + + // Convert to HTTP URL if needed + // The external URL might be in native protocol format (port 9000) + // We need HTTP port (typically 8123 or 8443 for HTTPS) + host := parsed.Hostname() + port := parsed.Port() + + // Determine scheme and port for HTTP API + scheme := "http" + + if parsed.Scheme == schemeHTTPS || parsed.Scheme == "clickhouses" { + scheme = schemeHTTPS + } + + // Map native port to HTTP port if needed + switch port { + case "9000": + port = "8123" // Default HTTP port + case "9440": + port = "8443" // Default HTTPS port + case "": + if scheme == schemeHTTPS { + port = "8443" + } else { + port = "8123" + } + } + + // Build HTTP URL with query parameters + httpURL := &url.URL{ + Scheme: scheme, + Host: fmt.Sprintf("%s:%s", host, port), + Path: "/", + } + + // Add authentication if configured + query := httpURL.Query() + + if g.cfg.Infrastructure.ClickHouse.Xatu.ExternalUsername != "" { + query.Set("user", g.cfg.Infrastructure.ClickHouse.Xatu.ExternalUsername) + } else if parsed.User != nil && parsed.User.Username() != "" { + query.Set("user", parsed.User.Username()) + } else { + query.Set("user", "default") + } + + if g.cfg.Infrastructure.ClickHouse.Xatu.ExternalPassword != "" { + query.Set("password", g.cfg.Infrastructure.ClickHouse.Xatu.ExternalPassword) + } else if parsed.User != nil { + if pass, ok := parsed.User.Password(); ok { + query.Set("password", pass) + } + } + + // Set database + if g.cfg.Infrastructure.ClickHouse.Xatu.ExternalDatabase != "" { + query.Set("database", g.cfg.Infrastructure.ClickHouse.Xatu.ExternalDatabase) + } else { + query.Set("database", "default") + } + + httpURL.RawQuery = query.Encode() + + return httpURL.String(), nil +} + +// ListExternalModels returns a list of available external models from the xatu-cbt repo. +func (g *Generator) ListExternalModels() ([]string, error) { + modelsDir := filepath.Join(g.cfg.Repos.XatuCBT, "models", "external") + + entries, err := os.ReadDir(modelsDir) + if err != nil { + return nil, fmt.Errorf("failed to read models directory: %w", err) + } + + models := make([]string, 0, len(entries)) + + for _, entry := range entries { + if entry.IsDir() { + continue + } + + name := entry.Name() + if strings.HasSuffix(name, ".sql") { + // Remove .sql extension to get model name + models = append(models, strings.TrimSuffix(name, ".sql")) + } + } + + return models, nil +} + +// ValidateModel checks if a model name is valid (exists in xatu-cbt external models). +func (g *Generator) ValidateModel(model string) error { + models, err := g.ListExternalModels() + if err != nil { + return fmt.Errorf("failed to list models: %w", err) + } + + for _, m := range models { + if m == model { + return nil + } + } + + return fmt.Errorf("model '%s' not found in xatu-cbt external models", model) +} diff --git a/pkg/seeddata/rangedetect.go b/pkg/seeddata/rangedetect.go new file mode 100644 index 0000000..dc9efd4 --- /dev/null +++ b/pkg/seeddata/rangedetect.go @@ -0,0 +1,135 @@ +package seeddata + +import ( + "bufio" + "fmt" + "os" + "path/filepath" + "regexp" + "strings" +) + +const ( + // DefaultRangeColumn is the fallback range column when detection fails. + DefaultRangeColumn = "slot_start_date_time" +) + +// rangeColumnPatterns are regex patterns to detect range columns from external model SQL. +// Order matters - more specific patterns should come first. +var rangeColumnPatterns = []*regexp.Regexp{ + // Pattern: toUnixTimestamp(min(column_name)) as min + regexp.MustCompile(`toUnixTimestamp\s*\(\s*min\s*\(\s*(\w+)\s*\)\s*\)`), + // Pattern: toUnixTimestamp(max(column_name)) as max + regexp.MustCompile(`toUnixTimestamp\s*\(\s*max\s*\(\s*(\w+)\s*\)\s*\)`), + // Pattern: min(column_name) as min + regexp.MustCompile(`(?:^|[^(])\bmin\s*\(\s*(\w+)\s*\)\s+as\s+min`), + // Pattern: max(column_name) as max + regexp.MustCompile(`(?:^|[^(])\bmax\s*\(\s*(\w+)\s*\)\s+as\s+max`), +} + +// DetectRangeColumn parses an external model SQL file to detect the range column +// used in bounds queries. It looks for patterns like toUnixTimestamp(min(column_name)). +// Falls back to DefaultRangeColumn if detection fails. +func DetectRangeColumn(externalModelPath string) (string, error) { + file, err := os.Open(externalModelPath) + if err != nil { + return "", fmt.Errorf("failed to open file: %w", err) + } + defer file.Close() + + scanner := bufio.NewScanner(file) + + var content strings.Builder + + for scanner.Scan() { + content.WriteString(scanner.Text()) + content.WriteString("\n") + } + + if err := scanner.Err(); err != nil { + return "", fmt.Errorf("error reading file: %w", err) + } + + sqlContent := strings.ToLower(content.String()) + + // Try each pattern to find the range column + for _, pattern := range rangeColumnPatterns { + matches := pattern.FindStringSubmatch(sqlContent) + if len(matches) > 1 { + return matches[1], nil + } + } + + // No pattern matched, return default + return DefaultRangeColumn, nil +} + +// DetectRangeColumnForModel detects the range column for an external model by name. +func DetectRangeColumnForModel(model, xatuCBTPath string) (string, error) { + modelPath := filepath.Join(xatuCBTPath, "models", "external", model+".sql") + + if _, err := os.Stat(modelPath); os.IsNotExist(err) { + return "", fmt.Errorf("external model '%s' not found", model) + } + + return DetectRangeColumn(modelPath) +} + +// RangeColumnInfo contains information about a model's range column. +type RangeColumnInfo struct { + Model string + RangeColumn string + Detected bool // true if detected from SQL, false if using default +} + +// DetectRangeColumnsForModels detects range columns for multiple external models. +// Returns a map of model name to range column info. +func DetectRangeColumnsForModels(models []string, xatuCBTPath string) (map[string]*RangeColumnInfo, error) { + result := make(map[string]*RangeColumnInfo, len(models)) + + for _, model := range models { + modelPath := filepath.Join(xatuCBTPath, "models", "external", model+".sql") + + if _, err := os.Stat(modelPath); os.IsNotExist(err) { + return nil, fmt.Errorf("external model '%s' not found", model) + } + + rangeCol, err := DetectRangeColumn(modelPath) + if err != nil { + return nil, fmt.Errorf("failed to detect range column for %s: %w", model, err) + } + + result[model] = &RangeColumnInfo{ + Model: model, + RangeColumn: rangeCol, + Detected: rangeCol != DefaultRangeColumn, + } + } + + return result, nil +} + +// FindCommonRangeColumn finds a common range column across all models. +// Returns the common column if all models share the same one, or DefaultRangeColumn otherwise. +func FindCommonRangeColumn(rangeInfos map[string]*RangeColumnInfo) string { + if len(rangeInfos) == 0 { + return DefaultRangeColumn + } + + var commonColumn string + + for _, info := range rangeInfos { + if commonColumn == "" { + commonColumn = info.RangeColumn + + continue + } + + if info.RangeColumn != commonColumn { + // Different range columns, fall back to default + return DefaultRangeColumn + } + } + + return commonColumn +} diff --git a/pkg/seeddata/ranges.go b/pkg/seeddata/ranges.go new file mode 100644 index 0000000..2a8bf95 --- /dev/null +++ b/pkg/seeddata/ranges.go @@ -0,0 +1,250 @@ +package seeddata + +import ( + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "strings" + "time" +) + +// ModelRange represents the available data range for a model. +type ModelRange struct { + Model string + Network string + RangeColumn string + Min time.Time + Max time.Time + MinRaw string // Original value from query (for display) + MaxRaw string +} + +// QueryModelRange queries external ClickHouse for a model's available data range. +// Uses ORDER BY ... LIMIT 1 instead of MIN/MAX for better performance on large tables. +func (g *Generator) QueryModelRange(ctx context.Context, model, network, rangeColumn string) (*ModelRange, error) { + // Query for minimum value (oldest data) + minQuery := fmt.Sprintf(` + SELECT %s as val + FROM default.%s + WHERE meta_network_name = '%s' + ORDER BY %s ASC + LIMIT 1 + FORMAT JSON + `, rangeColumn, model, network, rangeColumn) + + g.log.WithField("query", minQuery).Debug("querying model min range") + + minResult, err := g.executeSingleValueQuery(ctx, minQuery) + if err != nil { + return nil, fmt.Errorf("failed to query min range for %s: %w", model, err) + } + + // Query for maximum value (newest data) + maxQuery := fmt.Sprintf(` + SELECT %s as val + FROM default.%s + WHERE meta_network_name = '%s' + ORDER BY %s DESC + LIMIT 1 + FORMAT JSON + `, rangeColumn, model, network, rangeColumn) + + g.log.WithField("query", maxQuery).Debug("querying model max range") + + maxResult, err := g.executeSingleValueQuery(ctx, maxQuery) + if err != nil { + return nil, fmt.Errorf("failed to query max range for %s: %w", model, err) + } + + return &ModelRange{ + Model: model, + Network: network, + RangeColumn: rangeColumn, + Min: minResult.Time, + Max: maxResult.Time, + MinRaw: minResult.Raw, + MaxRaw: maxResult.Raw, + }, nil +} + +// singleValueResult holds a single time value result. +type singleValueResult struct { + Time time.Time + Raw string +} + +// executeSingleValueQuery executes a query that returns a single time value. +func (g *Generator) executeSingleValueQuery(ctx context.Context, query string) (*singleValueResult, error) { + chURL, err := g.buildClickHouseHTTPURL() + if err != nil { + return nil, fmt.Errorf("failed to build ClickHouse URL: %w", err) + } + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, chURL, strings.NewReader(query)) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + + req.Header.Set("Content-Type", "text/plain") + + client := &http.Client{ + Timeout: 30 * time.Second, // Shorter timeout for indexed queries + } + + resp, err := client.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to execute request: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + + return nil, fmt.Errorf("ClickHouse returned status %d: %s", resp.StatusCode, string(body)) + } + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response: %w", err) + } + + var jsonResp clickHouseJSONResponse + + if unmarshalErr := json.Unmarshal(body, &jsonResp); unmarshalErr != nil { + return nil, fmt.Errorf("failed to parse JSON response: %w", unmarshalErr) + } + + if len(jsonResp.Data) == 0 { + return nil, fmt.Errorf("no data returned from query") + } + + row := jsonResp.Data[0] + + val, ok := row["val"] + if !ok { + return nil, fmt.Errorf("val not found in response") + } + + t, raw, parseErr := parseTimeValue(val) + if parseErr != nil { + return nil, fmt.Errorf("failed to parse time value: %w", parseErr) + } + + return &singleValueResult{ + Time: t, + Raw: raw, + }, nil +} + +// clickHouseJSONResponse represents ClickHouse JSON format response. +type clickHouseJSONResponse struct { + Data []map[string]any `json:"data"` +} + +// parseTimeValue parses a time value from ClickHouse JSON response. +// It handles both DateTime strings and Unix timestamps. +func parseTimeValue(val any) (time.Time, string, error) { + switch v := val.(type) { + case string: + // Try parsing as DateTime string + t, err := time.Parse("2006-01-02 15:04:05", v) + if err != nil { + // Try with timezone + t, err = time.Parse(time.RFC3339, v) + if err != nil { + return time.Time{}, v, fmt.Errorf("failed to parse time string '%s': %w", v, err) + } + } + + return t, v, nil + + case float64: + // Unix timestamp (JSON numbers are float64) + t := time.Unix(int64(v), 0).UTC() + + return t, fmt.Sprintf("%.0f", v), nil + + case int64: + t := time.Unix(v, 0).UTC() + + return t, fmt.Sprintf("%d", v), nil + + default: + return time.Time{}, fmt.Sprintf("%v", val), fmt.Errorf("unsupported time value type: %T", val) + } +} + +// QueryModelRanges queries ranges for multiple models. +// If overrideColumn is non-empty, it will be used for all models instead of detected columns. +func (g *Generator) QueryModelRanges(ctx context.Context, models []string, network string, rangeInfos map[string]*RangeColumnInfo, overrideColumn string) ([]*ModelRange, error) { + ranges := make([]*ModelRange, 0, len(models)) + + for _, model := range models { + rangeCol := DefaultRangeColumn + + // Use override if provided, otherwise use detected column + if overrideColumn != "" { + rangeCol = overrideColumn + } else if info, ok := rangeInfos[model]; ok { + rangeCol = info.RangeColumn + } + + modelRange, err := g.QueryModelRange(ctx, model, network, rangeCol) + if err != nil { + return nil, fmt.Errorf("failed to query range for %s: %w", model, err) + } + + ranges = append(ranges, modelRange) + } + + return ranges, nil +} + +// FindIntersection finds the overlapping range across all model ranges. +// Returns nil if there is no intersection. +func FindIntersection(ranges []*ModelRange) (*ModelRange, error) { + if len(ranges) == 0 { + return nil, fmt.Errorf("no ranges provided") + } + + if len(ranges) == 1 { + return ranges[0], nil + } + + // Find the maximum of all minimums and minimum of all maximums + maxMin := ranges[0].Min + minMax := ranges[0].Max + + for _, r := range ranges[1:] { + if r.Min.After(maxMin) { + maxMin = r.Min + } + + if r.Max.Before(minMax) { + minMax = r.Max + } + } + + // Check if there's an intersection + if maxMin.After(minMax) { + return nil, fmt.Errorf("no intersecting range found: ranges do not overlap") + } + + return &ModelRange{ + Model: "intersection", + RangeColumn: ranges[0].RangeColumn, + Min: maxMin, + Max: minMax, + MinRaw: maxMin.Format("2006-01-02 15:04:05"), + MaxRaw: minMax.Format("2006-01-02 15:04:05"), + }, nil +} + +// FormatRange returns a human-readable string representation of the range. +func (r *ModelRange) FormatRange() string { + return fmt.Sprintf("%s to %s", + r.Min.Format("2006-01-02 15:04:05"), + r.Max.Format("2006-01-02 15:04:05")) +} diff --git a/pkg/seeddata/s3.go b/pkg/seeddata/s3.go new file mode 100644 index 0000000..3018cad --- /dev/null +++ b/pkg/seeddata/s3.go @@ -0,0 +1,251 @@ +package seeddata + +import ( + "context" + "errors" + "fmt" + "os" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/config" + "github.com/aws/aws-sdk-go-v2/service/s3" + "github.com/aws/aws-sdk-go-v2/service/s3/types" + awshttp "github.com/aws/smithy-go/transport/http" + "github.com/sirupsen/logrus" +) + +const ( + // DefaultS3Bucket is the default S3 bucket for xatu-cbt seed data. + DefaultS3Bucket = "ethpandaops-platform-production-public" + + // DefaultS3PublicDomain is the public domain for the S3 bucket. + DefaultS3PublicDomain = "data.ethpandaops.io" + + // DefaultS3Prefix is the default path prefix in the S3 bucket. + DefaultS3Prefix = "xatu-cbt" + + // DefaultS3Region is the default region (required by SDK, but endpoint controls routing). + DefaultS3Region = "us-east-1" + + // DefaultS3Endpoint is the default S3 endpoint (Cloudflare R2). + DefaultS3Endpoint = "https://539bc53131934672bf85e7260ec0b218.r2.cloudflarestorage.com" + + // EnvS3Endpoint is the environment variable for custom S3 endpoint. + EnvS3Endpoint = "S3_ENDPOINT" + + // EnvS3Bucket is the environment variable for custom S3 bucket name. + EnvS3Bucket = "S3_BUCKET" +) + +// S3Uploader handles uploading parquet files to S3. +type S3Uploader struct { + log logrus.FieldLogger + client *s3.Client + bucket string + publicDomain string + prefix string +} + +// NewS3Uploader creates a new S3 uploader. +// It reads AWS credentials from environment variables or AWS_PROFILE. +// For S3-compatible services (DigitalOcean Spaces, MinIO, etc.), set S3_ENDPOINT. +// To use a custom bucket, set S3_BUCKET. +func NewS3Uploader(ctx context.Context, log logrus.FieldLogger) (*S3Uploader, error) { + // Load AWS config - always set region (required by SDK even with custom endpoint) + cfg, err := config.LoadDefaultConfig(ctx, + config.WithRegion(DefaultS3Region), + ) + if err != nil { + return nil, fmt.Errorf("failed to load AWS config: %w", err) + } + + // Use custom endpoint or default to DigitalOcean Spaces AMS3 + endpoint := os.Getenv(EnvS3Endpoint) + if endpoint == "" { + endpoint = DefaultS3Endpoint + } + + log.WithField("endpoint", endpoint).Debug("using S3 endpoint") + + client := s3.NewFromConfig(cfg, func(o *s3.Options) { + o.BaseEndpoint = &endpoint + o.UsePathStyle = true // Required for most S3-compatible services + }) + + // Check for custom bucket + bucket := DefaultS3Bucket + if customBucket := os.Getenv(EnvS3Bucket); customBucket != "" { + bucket = customBucket + } + + return &S3Uploader{ + log: log.WithField("component", "s3uploader"), + client: client, + bucket: bucket, + publicDomain: DefaultS3PublicDomain, + prefix: DefaultS3Prefix, + }, nil +} + +// UploadOptions contains options for uploading to S3. +type UploadOptions struct { + LocalPath string // Path to local file + Network string // Network name (e.g., "mainnet", "sepolia") + Spec string // Fork spec (e.g., "pectra", "fusaka") + Model string // Model name (e.g., "beacon_api_eth_v1_events_block") + Filename string // Custom filename (without .parquet extension, defaults to Model) +} + +// UploadResult contains the result of an S3 upload. +type UploadResult struct { + S3URL string // S3 URL (s3://bucket/path) + PublicURL string // Public HTTPS URL +} + +// Upload uploads a parquet file to S3. +func (u *S3Uploader) Upload(ctx context.Context, opts UploadOptions) (*UploadResult, error) { + // Use custom filename or default to model name + filename := opts.Filename + if filename == "" { + filename = opts.Model + } + + // Build S3 key + key := fmt.Sprintf("%s/%s/%s/%s.parquet", u.prefix, opts.Network, opts.Spec, filename) + + u.log.WithFields(logrus.Fields{ + "bucket": u.bucket, + "key": key, + "file": opts.LocalPath, + }).Debug("uploading to S3") + + // Open local file + file, err := os.Open(opts.LocalPath) + if err != nil { + return nil, fmt.Errorf("failed to open file: %w", err) + } + defer file.Close() + + // Get file size for explicit ContentLength + fileInfo, err := file.Stat() + if err != nil { + return nil, fmt.Errorf("failed to stat file: %w", err) + } + + fileSize := fileInfo.Size() + + u.log.WithFields(logrus.Fields{ + "file": opts.LocalPath, + "size": fileSize, + "key": key, + }).Debug("uploading file to S3") + + // Upload to S3 with explicit content length + // Cache-Control: no-cache ensures CDN/browsers always revalidate with origin + // R2 will use ETag for efficient conditional requests (304 Not Modified) + _, err = u.client.PutObject(ctx, &s3.PutObjectInput{ + Bucket: aws.String(u.bucket), + Key: aws.String(key), + Body: file, + ContentType: aws.String("application/octet-stream"), + ContentLength: aws.Int64(fileSize), + CacheControl: aws.String("no-cache"), + }) + if err != nil { + return nil, fmt.Errorf("failed to upload to S3: %w", err) + } + + // Verify upload by checking object metadata + headResp, headErr := u.client.HeadObject(ctx, &s3.HeadObjectInput{ + Bucket: aws.String(u.bucket), + Key: aws.String(key), + }) + if headErr != nil { + u.log.WithError(headErr).Warn("failed to verify upload") + } else if headResp.ContentLength != nil && *headResp.ContentLength != fileSize { + return nil, fmt.Errorf("upload verification failed: expected %d bytes but S3 reports %d bytes", + fileSize, *headResp.ContentLength) + } else { + u.log.WithField("verified_size", *headResp.ContentLength).Debug("upload verified") + } + + return &UploadResult{ + S3URL: fmt.Sprintf("s3://%s/%s", u.bucket, key), + PublicURL: fmt.Sprintf("https://%s/%s", u.publicDomain, key), + }, nil +} + +// SetBucket sets a custom S3 bucket (for testing or custom destinations). +func (u *S3Uploader) SetBucket(bucket string) { + u.bucket = bucket +} + +// SetPrefix sets a custom S3 prefix (for testing or custom destinations). +func (u *S3Uploader) SetPrefix(prefix string) { + u.prefix = prefix +} + +// ObjectExists checks if an object already exists at the given path. +// Returns true if the object exists, false otherwise. +func (u *S3Uploader) ObjectExists(ctx context.Context, network, spec, filename string) (bool, error) { + key := fmt.Sprintf("%s/%s/%s/%s.parquet", u.prefix, network, spec, filename) + + _, err := u.client.HeadObject(ctx, &s3.HeadObjectInput{ + Bucket: aws.String(u.bucket), + Key: aws.String(key), + }) + if err != nil { + // Check if it's a "not found" error + var notFound *types.NotFound + if errors.As(err, ¬Found) { + return false, nil + } + + // Check for NoSuchKey error (some S3-compatible services use this) + var noSuchKey *types.NoSuchKey + if errors.As(err, &noSuchKey) { + return false, nil + } + + // For other errors, check if it's a 404 status code + var respErr *awshttp.ResponseError + if errors.As(err, &respErr) && respErr.HTTPStatusCode() == 404 { + return false, nil + } + + return false, fmt.Errorf("failed to check object existence: %w", err) + } + + return true, nil +} + +// GetPublicURL returns the public URL for an object without uploading. +func (u *S3Uploader) GetPublicURL(network, spec, filename string) string { + key := fmt.Sprintf("%s/%s/%s/%s.parquet", u.prefix, network, spec, filename) + + return fmt.Sprintf("https://%s/%s", u.publicDomain, key) +} + +// CheckAccess verifies the uploader has write access to the S3 bucket. +// It attempts to list objects at the prefix to verify credentials and permissions. +func (u *S3Uploader) CheckAccess(ctx context.Context) error { + u.log.WithFields(logrus.Fields{ + "bucket": u.bucket, + "prefix": u.prefix, + }).Debug("checking S3 access") + + // Try to list objects at the prefix - this verifies: + // 1. AWS credentials are valid + // 2. User has at least read access to the bucket + // Note: This doesn't guarantee write access, but catches most common issues + _, err := u.client.ListObjectsV2(ctx, &s3.ListObjectsV2Input{ + Bucket: aws.String(u.bucket), + Prefix: aws.String(u.prefix + "/"), + MaxKeys: aws.Int32(1), + }) + if err != nil { + return fmt.Errorf("S3 access check failed (bucket: %s): %w", u.bucket, err) + } + + return nil +} diff --git a/pkg/seeddata/sanitize.go b/pkg/seeddata/sanitize.go new file mode 100644 index 0000000..e3bbdfb --- /dev/null +++ b/pkg/seeddata/sanitize.go @@ -0,0 +1,250 @@ +package seeddata + +import ( + "context" + "crypto/rand" + "encoding/hex" + "encoding/json" + "fmt" + "io" + "net/http" + "strings" + "time" +) + +const ( + // saltLength is the number of random bytes used for salt generation. + saltLength = 32 +) + +// GenerateSalt creates a cryptographically random salt for IP sanitization. +// The salt should be generated once per seed data generation run and shared +// across all models to ensure consistent IP anonymization. +func GenerateSalt() (string, error) { + bytes := make([]byte, saltLength) + + if _, err := rand.Read(bytes); err != nil { + return "", fmt.Errorf("failed to generate random salt: %w", err) + } + + return hex.EncodeToString(bytes), nil +} + +// ColumnInfo represents a column's name and type from ClickHouse schema. +type ColumnInfo struct { + Name string + Type string +} + +// DescribeTable queries ClickHouse to get the schema for a table. +func (g *Generator) DescribeTable(ctx context.Context, model string) ([]ColumnInfo, error) { + query := fmt.Sprintf("DESCRIBE TABLE default.%s FORMAT JSON", model) + + chURL, err := g.buildClickHouseHTTPURL() + if err != nil { + return nil, fmt.Errorf("failed to build ClickHouse URL: %w", err) + } + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, chURL, strings.NewReader(query)) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + + req.Header.Set("Content-Type", "text/plain") + + client := &http.Client{ + Timeout: 30 * time.Second, + } + + resp, err := client.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to execute request: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + + return nil, fmt.Errorf("ClickHouse returned status %d: %s", resp.StatusCode, string(body)) + } + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response: %w", err) + } + + var descResp describeTableResponse + if unmarshalErr := json.Unmarshal(body, &descResp); unmarshalErr != nil { + return nil, fmt.Errorf("failed to parse JSON response: %w", unmarshalErr) + } + + columns := make([]ColumnInfo, 0, len(descResp.Data)) + + for _, row := range descResp.Data { + columns = append(columns, ColumnInfo(row)) + } + + return columns, nil +} + +// describeTableResponse represents the ClickHouse DESCRIBE TABLE JSON response. +type describeTableResponse struct { + Data []describeTableRow `json:"data"` +} + +// describeTableRow represents a single row from DESCRIBE TABLE. +type describeTableRow struct { + Name string `json:"name"` + Type string `json:"type"` +} + +// IsIPColumn checks if a column type is an IP address type (IPv4 or IPv6). +// Handles both direct types and Nullable wrappers. +func IsIPColumn(colType string) bool { + // Normalize the type string for comparison + normalized := strings.TrimSpace(colType) + + // Direct IP types + if normalized == "IPv4" || normalized == "IPv6" { + return true + } + + // Nullable IP types + if normalized == "Nullable(IPv4)" || normalized == "Nullable(IPv6)" { + return true + } + + return false +} + +// IsIPv4Column checks if a column is specifically IPv4 type. +func IsIPv4Column(colType string) bool { + normalized := strings.TrimSpace(colType) + + return normalized == "IPv4" || normalized == "Nullable(IPv4)" +} + +// IsNullableIPColumn checks if an IP column is nullable. +func IsNullableIPColumn(colType string) bool { + return strings.HasPrefix(strings.TrimSpace(colType), "Nullable(") +} + +// BuildSanitizedColumnExpr builds a SQL expression for a column with IP sanitization if needed. +// Preserves IP types: IPv4 → IPv4, IPv6 → IPv6, IPv4-mapped-IPv6 → IPv4-mapped-IPv6. +func BuildSanitizedColumnExpr(col ColumnInfo, salt string) string { + if !IsIPColumn(col.Type) { + // Non-IP column: select as-is + return col.Name + } + + // Escape the salt for SQL (single quotes doubled) + escapedSalt := strings.ReplaceAll(salt, "'", "''") + + // IPv4 columns: straightforward hash to IPv4 + if IsIPv4Column(col.Type) { + if IsNullableIPColumn(col.Type) { + // Nullable(IPv4): preserve NULL, hash non-NULL to IPv4 + return fmt.Sprintf( + "if(%s IS NOT NULL, toIPv4(reinterpret(substring(sipHash128(%s, '%s'), 1, 4), 'UInt32')), NULL) AS %s", + col.Name, col.Name, escapedSalt, col.Name, + ) + } + + // Non-nullable IPv4: hash to IPv4 + return fmt.Sprintf( + "toIPv4(reinterpret(substring(sipHash128(%s, '%s'), 1, 4), 'UInt32')) AS %s", + col.Name, escapedSalt, col.Name, + ) + } + + // IPv6 columns: detect IPv4-mapped addresses and preserve their format + // IPv4-mapped addresses start with '::ffff:' when converted to string + if IsNullableIPColumn(col.Type) { + // Nullable(IPv6): preserve NULL, detect IPv4-mapped vs native IPv6 + return fmt.Sprintf( + "if(%s IS NOT NULL, "+ + "if(startsWith(IPv6NumToString(%s), '::ffff:'), "+ + "IPv4ToIPv6(toIPv4(reinterpret(substring(sipHash128(%s, '%s'), 1, 4), 'UInt32'))), "+ + "CAST(reinterpret(sipHash128(%s, '%s'), 'FixedString(16)') AS IPv6)), "+ + "NULL) AS %s", + col.Name, col.Name, col.Name, escapedSalt, col.Name, escapedSalt, col.Name, + ) + } + + // Non-nullable IPv6: detect IPv4-mapped vs native IPv6 + return fmt.Sprintf( + "if(startsWith(IPv6NumToString(%s), '::ffff:'), "+ + "IPv4ToIPv6(toIPv4(reinterpret(substring(sipHash128(%s, '%s'), 1, 4), 'UInt32'))), "+ + "CAST(reinterpret(sipHash128(%s, '%s'), 'FixedString(16)') AS IPv6)) AS %s", + col.Name, col.Name, escapedSalt, col.Name, escapedSalt, col.Name, + ) +} + +// SanitizedColumnResult contains the result of building a sanitized column list. +type SanitizedColumnResult struct { + ColumnExpr string // Comma-separated column expressions for SELECT + SanitizedColumns []string // Names of columns that were sanitized (for display) +} + +// BuildSanitizedColumnList builds a complete SELECT column list with IP sanitization. +// Returns the column expressions and a list of which columns were sanitized. +func (g *Generator) BuildSanitizedColumnList(ctx context.Context, model, salt string) (*SanitizedColumnResult, error) { + columns, err := g.DescribeTable(ctx, model) + if err != nil { + return nil, fmt.Errorf("failed to describe table %s: %w", model, err) + } + + if len(columns) == 0 { + return nil, fmt.Errorf("table %s has no columns", model) + } + + // Find IP columns for reporting + sanitizedCols := make([]string, 0) + + for _, col := range columns { + if IsIPColumn(col.Type) { + sanitizedCols = append(sanitizedCols, fmt.Sprintf("%s (%s)", col.Name, col.Type)) + } + } + + // Build column expressions + exprs := make([]string, 0, len(columns)) + + for _, col := range columns { + expr := BuildSanitizedColumnExpr(col, salt) + exprs = append(exprs, expr) + } + + return &SanitizedColumnResult{ + ColumnExpr: strings.Join(exprs, ", "), + SanitizedColumns: sanitizedCols, + }, nil +} + +// CountIPColumns counts the number of IP columns in a table schema. +// Useful for logging/debugging. +func CountIPColumns(columns []ColumnInfo) int { + count := 0 + + for _, col := range columns { + if IsIPColumn(col.Type) { + count++ + } + } + + return count +} + +// GetIPColumnNames returns the names of all IP columns in a schema. +// Useful for logging/debugging. +func GetIPColumnNames(columns []ColumnInfo) []string { + names := make([]string, 0) + + for _, col := range columns { + if IsIPColumn(col.Type) { + names = append(names, col.Name) + } + } + + return names +} diff --git a/pkg/seeddata/template.go b/pkg/seeddata/template.go new file mode 100644 index 0000000..aedff4d --- /dev/null +++ b/pkg/seeddata/template.go @@ -0,0 +1,121 @@ +package seeddata + +import ( + "bytes" + "fmt" + "sort" + "text/template" + + "gopkg.in/yaml.v3" +) + +// TestYAMLTemplate is the template for generating xatu-cbt test YAML files. +const TestYAMLTemplate = `model: {{ .Model }} +network: {{ .Network }} +spec: {{ .Spec }} +external_data: + {{ .Model }}: + url: {{ .URL }} + network_column: meta_network_name +assertions: + - name: total count + sql: | + SELECT COUNT(*) AS count + FROM cluster('{remote_cluster}', default.{{ .Model }}) + expected: + count: {{ .RowCount }} +` + +// TemplateData contains the data for generating a test YAML template. +type TemplateData struct { + Model string // Model/table name + Network string // Network name (e.g., "mainnet", "sepolia") + Spec string // Fork spec (e.g., "pectra", "fusaka") + URL string // URL to the parquet file + RowCount int64 // Number of rows in the parquet file +} + +// GenerateTestYAML generates a test YAML string from the template data. +func GenerateTestYAML(data TemplateData) (string, error) { + tmpl, err := template.New("test").Parse(TestYAMLTemplate) + if err != nil { + return "", fmt.Errorf("failed to parse template: %w", err) + } + + var buf bytes.Buffer + + if execErr := tmpl.Execute(&buf, data); execErr != nil { + return "", fmt.Errorf("failed to execute template: %w", execErr) + } + + return buf.String(), nil +} + +// ExternalDataEntry represents an external data entry in the test YAML. +type ExternalDataEntry struct { + URL string `yaml:"url"` + NetworkColumn string `yaml:"network_column"` //nolint:tagliatelle // xatu-cbt uses snake_case +} + +// TransformationTestYAML represents the complete test YAML structure. +type TransformationTestYAML struct { + Model string `yaml:"model"` + Network string `yaml:"network"` + Spec string `yaml:"spec"` + ExternalData map[string]ExternalDataEntry `yaml:"external_data"` //nolint:tagliatelle // xatu-cbt uses snake_case + Assertions []Assertion `yaml:"assertions"` +} + +// TransformationTemplateData contains the data for generating transformation test YAML. +type TransformationTemplateData struct { + Model string // Transformation model name + Network string // Network name + Spec string // Fork spec + ExternalModels []string // List of external model names + URLs map[string]string // model name -> parquet URL + Assertions []Assertion // Generated assertions +} + +// GenerateTransformationTestYAML generates a complete test YAML for transformation models. +func GenerateTransformationTestYAML(data TransformationTemplateData) (string, error) { + testYAML := TransformationTestYAML{ + Model: data.Model, + Network: data.Network, + Spec: data.Spec, + ExternalData: make(map[string]ExternalDataEntry, len(data.ExternalModels)), + Assertions: data.Assertions, + } + + // Sort external models for consistent output + sortedModels := make([]string, len(data.ExternalModels)) + copy(sortedModels, data.ExternalModels) + sort.Strings(sortedModels) + + for _, model := range sortedModels { + url, ok := data.URLs[model] + if !ok { + return "", fmt.Errorf("missing URL for external model: %s", model) + } + + testYAML.ExternalData[model] = ExternalDataEntry{ + URL: url, + NetworkColumn: "meta_network_name", + } + } + + // If no assertions provided, use default + if len(testYAML.Assertions) == 0 { + testYAML.Assertions = GetDefaultAssertions(data.Model) + } + + var buf bytes.Buffer + + encoder := yaml.NewEncoder(&buf) + encoder.SetIndent(4) + + if err := encoder.Encode(testYAML); err != nil { + return "", fmt.Errorf("failed to encode YAML: %w", err) + } + + return buf.String(), nil +} diff --git a/pkg/ui/input.go b/pkg/ui/input.go new file mode 100644 index 0000000..375da6f --- /dev/null +++ b/pkg/ui/input.go @@ -0,0 +1,40 @@ +package ui + +import ( + "fmt" + + "github.com/pterm/pterm" +) + +// TextInput displays an interactive text input and returns the entered value. +// Returns the entered text, or error if cancelled/failed. +func TextInput(prompt string, defaultValue string) (string, error) { + input := pterm.DefaultInteractiveTextInput. + WithDefaultText(prompt) + + if defaultValue != "" { + input = input.WithDefaultValue(defaultValue) + } + + result, err := input.Show() + if err != nil { + return "", fmt.Errorf("%w: %w", ErrSelectionCancelled, err) + } + + return result, nil +} + +// TextInputRequired displays an interactive text input that requires a non-empty value. +// Returns the entered text, or error if cancelled/failed/empty. +func TextInputRequired(prompt string) (string, error) { + result, err := TextInput(prompt, "") + if err != nil { + return "", err + } + + if result == "" { + return "", fmt.Errorf("input required") + } + + return result, nil +}