From ccd56b1df2faae7652662411b8b7654fec94cc76 Mon Sep 17 00:00:00 2001 From: Matty Evans Date: Mon, 15 Dec 2025 11:30:48 +1000 Subject: [PATCH] feat: enhance seed-data discovery with dimension table support and SQL filter analysis - add support for entity/dimension tables (no time range) via intervalType - read intermediate transformation SQL to extract WHERE clause filters - extend discovery prompt to include correlation filters for dimension tables - add FilterSQL and CorrelationFilter to TableRangeStrategy for precise filtering - improve fallback discovery to handle entity models and missing ranges - normalize YAML field names and fix unquoted datetime values in Claude responses - extend QueryRowCount and GenerateOptions to accept additional SQL filters - add S3 Cache-Control: no-cache header for fresh seed data downloads --- .../lab_xatu_cbt_generate_transformation.go | 119 +++- pkg/seeddata/assertions.go | 16 +- pkg/seeddata/dependencies.go | 197 +++++- pkg/seeddata/discovery.go | 604 ++++++++++++++---- pkg/seeddata/generator.go | 38 +- pkg/seeddata/s3.go | 3 + 6 files changed, 802 insertions(+), 175 deletions(-) diff --git a/pkg/commands/lab_xatu_cbt_generate_transformation.go b/pkg/commands/lab_xatu_cbt_generate_transformation.go index 7454c9f..dccf32b 100644 --- a/pkg/commands/lab_xatu_cbt_generate_transformation.go +++ b/pkg/commands/lab_xatu_cbt_generate_transformation.go @@ -214,7 +214,7 @@ func runGenerateTransformationTest( return fmt.Errorf("failed to detect range columns: %w", err) } - discoveryResult, err = seeddata.FallbackRangeDiscovery(ctx, gen, externalModels, network, rangeInfos, duration) + discoveryResult, err = seeddata.FallbackRangeDiscovery(ctx, gen, externalModels, network, rangeInfos, duration, labCfg.Repos.XatuCBT) if err != nil { return fmt.Errorf("fallback range discovery failed: %w", err) } @@ -254,6 +254,26 @@ func runGenerateTransformationTest( return fmt.Errorf("failed to read transformation SQL: %w", sqlErr) } + // Read intermediate dependency SQL (for WHERE clause analysis) + intermediateSQL, intErr := seeddata.ReadIntermediateSQL(tree, labCfg.Repos.XatuCBT) + if intErr != nil { + ui.Warning(fmt.Sprintf("Could not read intermediate SQL: %v", intErr)) + // Continue without intermediate SQL - not critical + } + + // Convert to IntermediateSQL slice + var intermediateModels []seeddata.IntermediateSQL + for modelName, sql := range intermediateSQL { + intermediateModels = append(intermediateModels, seeddata.IntermediateSQL{ + Model: modelName, + SQL: sql, + }) + } + + if len(intermediateModels) > 0 { + ui.Info(fmt.Sprintf("Including %d intermediate model(s) for WHERE clause analysis", len(intermediateModels))) + } + // Invoke Claude for analysis ui.Blank() @@ -262,6 +282,7 @@ func runGenerateTransformationTest( discoveryResult, err = discoveryClient.AnalyzeRanges(ctx, seeddata.DiscoveryInput{ TransformationModel: model, TransformationSQL: transformationSQL, + IntermediateModels: intermediateModels, Network: network, Duration: duration, ExternalModels: schemaInfo, @@ -277,7 +298,7 @@ func runGenerateTransformationTest( return fmt.Errorf("failed to detect range columns: %w", rangeErr) } - discoveryResult, err = seeddata.FallbackRangeDiscovery(ctx, gen, externalModels, network, rangeInfos, duration) + discoveryResult, err = seeddata.FallbackRangeDiscovery(ctx, gen, externalModels, network, rangeInfos, duration, labCfg.Repos.XatuCBT) if err != nil { return fmt.Errorf("fallback range discovery failed: %w", err) } @@ -345,14 +366,44 @@ func runGenerateTransformationTest( bridgeInfo = fmt.Sprintf(" (via %s)", strategy.BridgeTable) } - ui.Info(fmt.Sprintf(" • %s: %s [%s → %s] %s%s", - strategy.Model, - strategy.RangeColumn, - strategy.FromValue, - strategy.ToValue, - confidence, - bridgeInfo, - )) + // Handle dimension tables (no range) vs regular tables + if strategy.RangeColumn == "" || strategy.ColumnType == seeddata.RangeColumnTypeNone { + ui.Info(fmt.Sprintf(" • %s: (dimension table - all data) %s%s", + strategy.Model, + confidence, + bridgeInfo, + )) + } else { + ui.Info(fmt.Sprintf(" • %s: %s [%s → %s] %s%s", + strategy.Model, + strategy.RangeColumn, + strategy.FromValue, + strategy.ToValue, + confidence, + bridgeInfo, + )) + } + + // Display additional filters if present + if strategy.FilterSQL != "" { + ui.Info(fmt.Sprintf(" Filter: %s", strategy.FilterSQL)) + } + + // Display correlation filter if present (for dimension tables) + if strategy.CorrelationFilter != "" { + // Truncate long subqueries for display + corrFilter := strategy.CorrelationFilter + if len(corrFilter) > 80 { + corrFilter = corrFilter[:77] + "..." + } + + ui.Info(fmt.Sprintf(" Correlation: %s", corrFilter)) + } + + // Display if optional + if strategy.Optional { + ui.Info(" (optional - LEFT JOIN)") + } } // Display warnings @@ -602,21 +653,47 @@ func runGenerateTransformationTest( } // Show query parameters (helps debug empty parquets) - ui.Info(fmt.Sprintf(" %s: %s [%s → %s]", extModel, strategy.RangeColumn, strategy.FromValue, strategy.ToValue)) + filterInfo := "" + if strategy.FilterSQL != "" { + filterInfo = fmt.Sprintf(" + filter: %s", strategy.FilterSQL) + } + + if strategy.CorrelationFilter != "" { + // Truncate long subqueries for display + corrFilter := strategy.CorrelationFilter + if len(corrFilter) > 60 { + corrFilter = corrFilter[:57] + "..." + } + + filterInfo += fmt.Sprintf(" + correlation: %s", corrFilter) + } + + // Handle dimension tables (no range) vs regular tables + if strategy.RangeColumn == "" || strategy.ColumnType == seeddata.RangeColumnTypeNone { + if strategy.CorrelationFilter != "" { + ui.Info(fmt.Sprintf(" %s: (correlated dimension table)%s", extModel, filterInfo)) + } else { + ui.Info(fmt.Sprintf(" %s: (dimension table - all data)%s", extModel, filterInfo)) + } + } else { + ui.Info(fmt.Sprintf(" %s: %s [%s → %s]%s", extModel, strategy.RangeColumn, strategy.FromValue, strategy.ToValue, filterInfo)) + } genSpinner := ui.NewSpinner(fmt.Sprintf("Generating %s", extModel)) result, genErr := gen.Generate(ctx, seeddata.GenerateOptions{ - Model: extModel, - Network: network, - Spec: spec, - RangeColumn: strategy.RangeColumn, - From: strategy.FromValue, - To: strategy.ToValue, - Limit: limit, - OutputPath: outputPath, - SanitizeIPs: sanitizeIPs, - Salt: salt, + Model: extModel, + Network: network, + Spec: spec, + RangeColumn: strategy.RangeColumn, + From: strategy.FromValue, + To: strategy.ToValue, + FilterSQL: strategy.FilterSQL, + CorrelationFilter: strategy.CorrelationFilter, + Limit: limit, + OutputPath: outputPath, + SanitizeIPs: sanitizeIPs, + Salt: salt, }) if genErr != nil { genSpinner.Fail(fmt.Sprintf("Failed to generate %s", extModel)) diff --git a/pkg/seeddata/assertions.go b/pkg/seeddata/assertions.go index 5554dd4..7f80471 100644 --- a/pkg/seeddata/assertions.go +++ b/pkg/seeddata/assertions.go @@ -220,7 +220,7 @@ func extractYAMLFromResponse(response string) string { return strings.TrimSpace(matches[1]) } - // If no code block, look for content starting with a YAML list + // If no code block, look for content starting with a YAML list or discovery YAML lines := strings.Split(response, "\n") var yamlLines []string @@ -229,13 +229,21 @@ func extractYAMLFromResponse(response string) string { for _, line := range lines { trimmed := strings.TrimSpace(line) - if strings.HasPrefix(trimmed, "- name:") { + + // Start YAML when we see assertion-style or discovery-style start markers + if strings.HasPrefix(trimmed, "- name:") || + strings.HasPrefix(trimmed, "primaryRangeType:") || + strings.HasPrefix(trimmed, "primary_range_type:") { inYAML = true } if inYAML { - // Stop if we hit non-YAML content - if trimmed != "" && !strings.HasPrefix(line, " ") && !strings.HasPrefix(line, "-") && !strings.HasPrefix(line, "\t") { + // Stop if we hit non-YAML content (text that doesn't look like YAML) + if trimmed != "" && + !strings.HasPrefix(line, " ") && + !strings.HasPrefix(line, "-") && + !strings.HasPrefix(line, "\t") && + !strings.Contains(line, ":") { break } diff --git a/pkg/seeddata/dependencies.go b/pkg/seeddata/dependencies.go index 5a41d41..d453539 100644 --- a/pkg/seeddata/dependencies.go +++ b/pkg/seeddata/dependencies.go @@ -40,10 +40,28 @@ type DependencyTree struct { // dependencyPattern matches dependency strings like "{{transformation}}.model_name" or "{{external}}.model_name". var dependencyPattern = regexp.MustCompile(`^\{\{(transformation|external)\}\}\.(.+)$`) +// IntervalType represents the interval type from external model frontmatter. +type IntervalType string + +const ( + // IntervalTypeSlot is slot-based interval (time ranges via slot_start_date_time). + IntervalTypeSlot IntervalType = "slot" + // IntervalTypeBlock is block number based interval. + IntervalTypeBlock IntervalType = "block" + // IntervalTypeEntity is for dimension/reference tables with no time range. + IntervalTypeEntity IntervalType = "entity" +) + +// intervalConfig represents the interval configuration in model frontmatter. +type intervalConfig struct { + Type IntervalType `yaml:"type"` +} + // sqlFrontmatter represents the YAML frontmatter in SQL files. type sqlFrontmatter struct { - Table string `yaml:"table"` - Dependencies []string `yaml:"dependencies"` + Table string `yaml:"table"` + Dependencies []string `yaml:"dependencies"` + Interval intervalConfig `yaml:"interval"` } // ParseDependencies parses the dependencies from a SQL file's YAML frontmatter. @@ -84,16 +102,16 @@ func ResolveDependencyTree(model string, xatuCBTPath string, visited map[string] defer func() { visited[model] = false }() - // First, try to find as transformation model - transformationPath := filepath.Join(xatuCBTPath, "models", "transformations", model+".sql") + // Try to find as transformation model (supports .sql and .yml extensions) + transformationPath := findModelFile(xatuCBTPath, "transformations", model) - if _, err := os.Stat(transformationPath); err == nil { + if transformationPath != "" { return resolveTransformationTree(model, transformationPath, xatuCBTPath, visited) } // If not found as transformation, check if it's an external model - externalPath := filepath.Join(xatuCBTPath, "models", "external", model+".sql") - if _, err := os.Stat(externalPath); err == nil { + externalPath := findModelFile(xatuCBTPath, "external", model) + if externalPath != "" { return &DependencyTree{ Model: model, Type: DependencyTypeExternal, @@ -105,6 +123,20 @@ func ResolveDependencyTree(model string, xatuCBTPath string, visited map[string] return nil, fmt.Errorf("model '%s' not found in transformations or external models", model) } +// findModelFile looks for a model file with supported extensions (.sql, .yml, .yaml). +func findModelFile(xatuCBTPath, folder, model string) string { + extensions := []string{".sql", ".yml", ".yaml"} + + for _, ext := range extensions { + path := filepath.Join(xatuCBTPath, "models", folder, model+ext) + if _, err := os.Stat(path); err == nil { + return path + } + } + + return "" +} + // resolveTransformationTree resolves a transformation model's dependency tree. func resolveTransformationTree(model, sqlPath, xatuCBTPath string, visited map[string]bool) (*DependencyTree, error) { deps, err := ParseDependencies(sqlPath) @@ -191,6 +223,67 @@ func (t *DependencyTree) PrintTree(indent string) string { return sb.String() } +// GetIntermediateDependencies returns all intermediate (transformation) model names from the +// dependency tree, excluding the root model. These are non-leaf nodes that transform external data. +// The result is deduplicated. +func (t *DependencyTree) GetIntermediateDependencies() []string { + seen := make(map[string]bool, 8) + intermediates := make([]string, 0, 8) + + t.collectIntermediateDeps(seen, &intermediates, true) + + return intermediates +} + +// collectIntermediateDeps recursively collects intermediate (transformation) dependencies. +func (t *DependencyTree) collectIntermediateDeps(seen map[string]bool, result *[]string, isRoot bool) { + // Skip external models (leaf nodes) + if t.Type == DependencyTypeExternal { + return + } + + // Add this model if it's not the root and not already seen + if !isRoot && !seen[t.Model] { + seen[t.Model] = true + *result = append(*result, t.Model) + } + + // Recurse into children + for _, child := range t.Children { + child.collectIntermediateDeps(seen, result, false) + } +} + +// ReadIntermediateSQL reads the SQL content for all intermediate dependencies. +// Returns a map of model name to SQL content. +// Note: YAML script models (.yml/.yaml) are skipped as they don't contain SQL to analyze. +func ReadIntermediateSQL(tree *DependencyTree, xatuCBTPath string) (map[string]string, error) { + intermediates := tree.GetIntermediateDependencies() + result := make(map[string]string, len(intermediates)) + + for _, model := range intermediates { + modelPath := findModelFile(xatuCBTPath, "transformations", model) + if modelPath == "" { + // Model file not found, skip + continue + } + + // Only read SQL files - YAML script models don't have SQL to analyze + if !strings.HasSuffix(modelPath, ".sql") { + continue + } + + content, err := os.ReadFile(modelPath) + if err != nil { + return nil, fmt.Errorf("failed to read SQL for %s: %w", model, err) + } + + result[model] = string(content) + } + + return result, nil +} + // ListTransformationModels returns a list of available transformation models from the xatu-cbt repo. func ListTransformationModels(xatuCBTPath string) ([]string, error) { modelsDir := filepath.Join(xatuCBTPath, "models", "transformations") @@ -208,17 +301,50 @@ func ListTransformationModels(xatuCBTPath string) ([]string, error) { } name := entry.Name() - if strings.HasSuffix(name, ".sql") { - // Remove .sql extension to get model name - models = append(models, strings.TrimSuffix(name, ".sql")) + + // Support .sql, .yml, and .yaml extensions + for _, ext := range []string{".sql", ".yml", ".yaml"} { + if strings.HasSuffix(name, ext) { + models = append(models, strings.TrimSuffix(name, ext)) + + break + } } } return models, nil } -// parseFrontmatter extracts and parses the YAML frontmatter from a SQL file. -func parseFrontmatter(sqlPath string) (*sqlFrontmatter, error) { +// parseFrontmatter extracts and parses the YAML frontmatter from a SQL file, +// or parses a pure YAML file (.yml/.yaml) directly. +func parseFrontmatter(modelPath string) (*sqlFrontmatter, error) { + // Check if this is a pure YAML file (not SQL with frontmatter) + if strings.HasSuffix(modelPath, ".yml") || strings.HasSuffix(modelPath, ".yaml") { + return parseYAMLFile(modelPath) + } + + // Parse SQL file with YAML frontmatter + return parseSQLFrontmatter(modelPath) +} + +// parseYAMLFile parses a pure YAML model file (.yml or .yaml). +func parseYAMLFile(yamlPath string) (*sqlFrontmatter, error) { + content, err := os.ReadFile(yamlPath) + if err != nil { + return nil, fmt.Errorf("failed to read file: %w", err) + } + + var fm sqlFrontmatter + + if err := yaml.Unmarshal(content, &fm); err != nil { + return nil, fmt.Errorf("failed to parse YAML file: %w", err) + } + + return &fm, nil +} + +// parseSQLFrontmatter extracts and parses the YAML frontmatter from a SQL file. +func parseSQLFrontmatter(sqlPath string) (*sqlFrontmatter, error) { file, err := os.Open(sqlPath) if err != nil { return nil, fmt.Errorf("failed to open file: %w", err) @@ -269,6 +395,53 @@ func parseFrontmatter(sqlPath string) (*sqlFrontmatter, error) { return &fm, nil } +// GetExternalModelIntervalType returns the interval type for an external model. +// Returns IntervalTypeEntity for dimension tables, or the actual type (slot, block, etc.). +func GetExternalModelIntervalType(model, xatuCBTPath string) (IntervalType, error) { + modelPath := findModelFile(xatuCBTPath, "external", model) + if modelPath == "" { + return "", fmt.Errorf("external model '%s' not found", model) + } + + fm, err := parseFrontmatter(modelPath) + if err != nil { + return "", fmt.Errorf("failed to parse frontmatter: %w", err) + } + + if fm.Interval.Type == "" { + // Default to slot if not specified + return IntervalTypeSlot, nil + } + + return fm.Interval.Type, nil +} + +// IsEntityModel checks if an external model is an entity/dimension table. +func IsEntityModel(model, xatuCBTPath string) bool { + intervalType, err := GetExternalModelIntervalType(model, xatuCBTPath) + if err != nil { + return false + } + + return intervalType == IntervalTypeEntity +} + +// GetExternalModelIntervalTypes returns interval types for multiple external models. +func GetExternalModelIntervalTypes(models []string, xatuCBTPath string) (map[string]IntervalType, error) { + result := make(map[string]IntervalType, len(models)) + + for _, model := range models { + intervalType, err := GetExternalModelIntervalType(model, xatuCBTPath) + if err != nil { + return nil, fmt.Errorf("failed to get interval type for %s: %w", model, err) + } + + result[model] = intervalType + } + + return result, nil +} + // parseDependencyString parses a dependency string like "{{transformation}}.model_name". func parseDependencyString(depStr string) (Dependency, error) { matches := dependencyPattern.FindStringSubmatch(depStr) diff --git a/pkg/seeddata/discovery.go b/pkg/seeddata/discovery.go index 8619d91..cd5e129 100644 --- a/pkg/seeddata/discovery.go +++ b/pkg/seeddata/discovery.go @@ -10,6 +10,7 @@ import ( "os" "os/exec" "path/filepath" + "regexp" "strings" "time" @@ -29,23 +30,27 @@ const ( RangeColumnTypeSlot RangeColumnType = "slot" // RangeColumnTypeEpoch represents epoch number columns. RangeColumnTypeEpoch RangeColumnType = "epoch" + // RangeColumnTypeNone represents dimension/reference tables with no time-based range. + RangeColumnTypeNone RangeColumnType = "none" // RangeColumnTypeUnknown represents an unclassified column type. RangeColumnTypeUnknown RangeColumnType = "unknown" ) // TableRangeStrategy describes how to filter a single table for seed data extraction. type TableRangeStrategy struct { - Model string `yaml:"model"` - RangeColumn string `yaml:"rangeColumn"` - ColumnType RangeColumnType `yaml:"columnType"` - FromValue string `yaml:"fromValue"` - ToValue string `yaml:"toValue"` - FilterSQL string `yaml:"filterSql,omitempty"` - RequiresBridge bool `yaml:"requiresBridge"` - BridgeTable string `yaml:"bridgeTable,omitempty"` - BridgeJoinSQL string `yaml:"bridgeJoinSql,omitempty"` - Confidence float64 `yaml:"confidence"` - Reasoning string `yaml:"reasoning"` + Model string `yaml:"model"` + RangeColumn string `yaml:"rangeColumn"` + ColumnType RangeColumnType `yaml:"columnType"` + FromValue string `yaml:"fromValue"` + ToValue string `yaml:"toValue"` + FilterSQL string `yaml:"filterSql,omitempty"` + CorrelationFilter string `yaml:"correlationFilter,omitempty"` // Subquery filter for dimension tables + Optional bool `yaml:"optional,omitempty"` // True if table is optional (LEFT JOIN) + RequiresBridge bool `yaml:"requiresBridge"` + BridgeTable string `yaml:"bridgeTable,omitempty"` + BridgeJoinSQL string `yaml:"bridgeJoinSql,omitempty"` + Confidence float64 `yaml:"confidence"` + Reasoning string `yaml:"reasoning"` } // DiscoveryResult contains the complete AI-generated range strategy. @@ -77,10 +82,11 @@ func (d *DiscoveryResult) GetStrategy(model string) *TableRangeStrategy { // TableSchemaInfo contains schema information for a table. type TableSchemaInfo struct { - Model string `yaml:"model"` - Columns []ColumnInfo `yaml:"columns"` - SampleData []map[string]any `yaml:"sampleData,omitempty"` - RangeInfo *DetectedRange `yaml:"rangeInfo,omitempty"` + Model string `yaml:"model"` + IntervalType IntervalType `yaml:"intervalType,omitempty"` // From model frontmatter (slot, block, entity) + Columns []ColumnInfo `yaml:"columns"` + SampleData []map[string]any `yaml:"sampleData,omitempty"` + RangeInfo *DetectedRange `yaml:"rangeInfo,omitempty"` } // DetectedRange contains detected range information. @@ -96,11 +102,18 @@ type DetectedRange struct { type DiscoveryInput struct { TransformationModel string `yaml:"transformationModel"` TransformationSQL string `yaml:"transformationSql"` + IntermediateModels []IntermediateSQL `yaml:"intermediateModels,omitempty"` // SQL for intermediate deps Network string `yaml:"network"` Duration string `yaml:"duration"` // e.g., "5m", "10m", "1h" ExternalModels []TableSchemaInfo `yaml:"externalModels"` } +// IntermediateSQL contains SQL for an intermediate transformation model. +type IntermediateSQL struct { + Model string `yaml:"model"` + SQL string `yaml:"sql"` +} + // ClaudeDiscoveryClient handles AI-assisted range discovery. type ClaudeDiscoveryClient struct { log logrus.FieldLogger @@ -139,6 +152,8 @@ func (c *ClaudeDiscoveryClient) IsAvailable() bool { } // GatherSchemaInfo collects schema information for all external models. +// All tables are treated equally - Claude will analyze the schema to determine +// the best filtering strategy for each table. func (c *ClaudeDiscoveryClient) GatherSchemaInfo( ctx context.Context, models []string, @@ -150,18 +165,37 @@ func (c *ClaudeDiscoveryClient) GatherSchemaInfo( for _, model := range models { c.log.WithField("model", model).Debug("gathering schema info") + // Get interval type from model frontmatter (informational context for Claude) + intervalType, err := GetExternalModelIntervalType(model, xatuCBTPath) + if err != nil { + c.log.WithError(err).WithField("model", model).Warn("failed to get interval type") + + intervalType = IntervalTypeSlot // Default to slot + } + // Get column schema from ClickHouse columns, err := c.gen.DescribeTable(ctx, model) if err != nil { return nil, fmt.Errorf("failed to describe table %s: %w", model, err) } - // Detect range column from SQL file - rangeCol, err := DetectRangeColumnForModel(model, xatuCBTPath) - if err != nil { - c.log.WithError(err).WithField("model", model).Warn("failed to detect range column") + schemaInfo := TableSchemaInfo{ + Model: model, + IntervalType: intervalType, + Columns: columns, + } - rangeCol = DefaultRangeColumn + // Try to detect range column from SQL file + rangeCol, detectErr := DetectRangeColumnForModel(model, xatuCBTPath) + if detectErr != nil { + c.log.WithError(detectErr).WithField("model", model).Debug("failed to detect range column from SQL") + + // For tables without a detected range column, find any time column in schema + // This handles entity tables and other tables without explicit range definitions + rangeCol = findTimeColumnInSchema(columns) + if rangeCol == "" { + rangeCol = DefaultRangeColumn // Last resort fallback + } } // Classify the range column type @@ -170,33 +204,31 @@ func (c *ClaudeDiscoveryClient) GatherSchemaInfo( // Query the range for this model var minVal, maxVal string - modelRange, err := c.gen.QueryModelRange(ctx, model, network, rangeCol) - if err != nil { - c.log.WithError(err).WithField("model", model).Warn("failed to query model range") + modelRange, rangeErr := c.gen.QueryModelRange(ctx, model, network, rangeCol) + if rangeErr != nil { + c.log.WithError(rangeErr).WithField("model", model).Warn("failed to query model range") } else { minVal = modelRange.MinRaw maxVal = modelRange.MaxRaw } // Get sample data (limited to 3 rows for prompt size) - sampleData, err := c.gen.QueryTableSample(ctx, model, network, 3) - if err != nil { - c.log.WithError(err).WithField("model", model).Warn("failed to query sample data") + sampleData, sampleErr := c.gen.QueryTableSample(ctx, model, network, 3) + if sampleErr != nil { + c.log.WithError(sampleErr).WithField("model", model).Warn("failed to query sample data") // Continue without sample data - not critical } - schemas = append(schemas, TableSchemaInfo{ - Model: model, - Columns: columns, - SampleData: sampleData, - RangeInfo: &DetectedRange{ - Column: rangeCol, - ColumnType: colType, - Detected: rangeCol != DefaultRangeColumn, - MinValue: minVal, - MaxValue: maxVal, - }, - }) + schemaInfo.SampleData = sampleData + schemaInfo.RangeInfo = &DetectedRange{ + Column: rangeCol, + ColumnType: colType, + Detected: detectErr == nil, // True if detected from SQL, false if fallback + MinValue: minVal, + MaxValue: maxVal, + } + + schemas = append(schemas, schemaInfo) } return schemas, nil @@ -264,14 +296,34 @@ func (c *ClaudeDiscoveryClient) buildDiscoveryPrompt(input DiscoveryInput) strin sb.WriteString("## Problem\n") sb.WriteString("These tables may use different range column types:\n") sb.WriteString("- Time-based columns (slot_start_date_time - DateTime)\n") - sb.WriteString("- Numeric columns (block_number - UInt64, slot - UInt64)\n\n") + sb.WriteString("- Numeric columns (block_number - UInt64, slot - UInt64)\n") + sb.WriteString("- **Dimension/reference tables** (no time range - static lookup data like validator entities)\n\n") sb.WriteString("You need to determine how to correlate these ranges so we get matching data across all tables.\n\n") + sb.WriteString("**CRITICAL**: The transformation and its intermediate dependencies may have WHERE clauses that filter data.\n") + sb.WriteString("If you extract seed data that doesn't match these filters, the transformation will produce ZERO output rows!\n") + sb.WriteString("You MUST analyze the SQL and include any necessary filters in `filterSql` for each external model.\n\n") + + sb.WriteString("**IMPORTANT**: ALL tables must be filtered to limit data volume. Look at each table's schema to find appropriate filter columns.\n\n") + + sb.WriteString("**DIMENSION/ENTITY TABLES**: For tables that are JOINed as lookups (like validator entities):\n") + sb.WriteString("1. Analyze the JOIN condition in the transformation SQL to find the join key (e.g., validator_index)\n") + sb.WriteString("2. Use `correlationFilter` to filter by values that exist in the primary data tables\n") + sb.WriteString("3. **IMPORTANT**: Use `GLOBAL IN` (not just `IN`) for subqueries - ClickHouse requires this for distributed tables\n") + sb.WriteString("4. Example: If attestations JOIN on validator_index, filter entities to only those validators\n") + sb.WriteString("5. Mark tables as `optional: true` if the transformation can produce output without them (LEFT JOINs)\n") + sb.WriteString("6. If correlation isn't possible, use a reasonable time-based filter on any available DateTime column\n\n") + sb.WriteString("## External Models and Their Schemas\n\n") for _, schema := range input.ExternalModels { sb.WriteString(fmt.Sprintf("### %s\n", schema.Model)) + // Show interval type from frontmatter (informational context for Claude) + if schema.IntervalType != "" { + sb.WriteString(fmt.Sprintf("Interval Type: %s\n", schema.IntervalType)) + } + if schema.RangeInfo != nil { sb.WriteString(fmt.Sprintf("Detected Range Column: %s (type: %s)\n", schema.RangeInfo.Column, schema.RangeInfo.ColumnType)) @@ -306,6 +358,16 @@ func (c *ClaudeDiscoveryClient) buildDiscoveryPrompt(input DiscoveryInput) strin sb.WriteString(input.TransformationSQL) sb.WriteString("\n```\n\n") + // Include intermediate dependency SQL if available + if len(input.IntermediateModels) > 0 { + sb.WriteString("## Intermediate Dependency SQL\n") + sb.WriteString("The transformation depends on these intermediate models. Their WHERE clauses affect which seed data is usable:\n\n") + + for _, intermediate := range input.IntermediateModels { + sb.WriteString(fmt.Sprintf("### %s\n```sql\n%s\n```\n\n", intermediate.Model, intermediate.SQL)) + } + } + sb.WriteString("## Instructions\n") sb.WriteString("1. Analyze which tables can share a common range column directly\n") sb.WriteString("2. For tables with different range types, determine if correlation is possible via:\n") @@ -314,37 +376,44 @@ func (c *ClaudeDiscoveryClient) buildDiscoveryPrompt(input DiscoveryInput) strin sb.WriteString(" - Shared columns in the data itself\n") sb.WriteString("3. Recommend a primary range specification (type + column + from/to values)\n") sb.WriteString(fmt.Sprintf("4. Use a %s time range (as requested by the user)\n", input.Duration)) - sb.WriteString("5. For each table, specify exactly how to filter it\n\n") + sb.WriteString("5. For each table, specify exactly how to filter it\n") + sb.WriteString("6. **CRITICAL**: Analyze ALL WHERE clauses in the transformation and intermediate SQL.\n") + sb.WriteString(" For each external model, identify any column filters that must be applied to get usable data.\n") + sb.WriteString(" Include these as `filterSql` - a SQL fragment like \"aggregation_bits = ''\" or \"attesting_validator_index IS NOT NULL\"\n\n") sb.WriteString("## Output Format\n") - sb.WriteString("Output ONLY valid YAML matching this structure:\n\n") + sb.WriteString("Output ONLY valid YAML matching this structure.\n") + sb.WriteString("**CRITICAL**: All datetime values MUST be quoted (e.g., \"2025-01-01 00:00:00\") - unquoted colons break YAML!\n\n") sb.WriteString("```yaml\n") - sb.WriteString("primaryRangeType: time # or block, slot\n") + sb.WriteString("primaryRangeType: time\n") sb.WriteString("primaryRangeColumn: slot_start_date_time\n") - sb.WriteString("fromValue: \"2025-01-01 00:00:00\" # or numeric value as string\n") + sb.WriteString("fromValue: \"2025-01-01 00:00:00\"\n") sb.WriteString("toValue: \"2025-01-01 00:05:00\"\n") sb.WriteString("strategies:\n") - sb.WriteString(" - model: table_name\n") - sb.WriteString(" rangeColumn: column_to_filter_on\n") - sb.WriteString(" columnType: time # or block, slot\n") - sb.WriteString(" fromValue: \"value\"\n") - sb.WriteString(" toValue: \"value\"\n") + sb.WriteString(" - model: beacon_api_eth_v1_events_attestation\n") + sb.WriteString(" rangeColumn: slot_start_date_time\n") + sb.WriteString(" columnType: time\n") + sb.WriteString(" fromValue: \"2025-01-01 00:00:00\"\n") + sb.WriteString(" toValue: \"2025-01-01 00:05:00\"\n") + sb.WriteString(" filterSql: \"aggregation_bits = '' AND attesting_validator_index IS NOT NULL\"\n") sb.WriteString(" requiresBridge: false\n") - sb.WriteString(" confidence: 0.95\n") - sb.WriteString(" reasoning: \"Direct filtering on native range column\"\n") - sb.WriteString(" - model: block_based_table\n") - sb.WriteString(" rangeColumn: block_number\n") - sb.WriteString(" columnType: block\n") - sb.WriteString(" fromValue: \"1000000\"\n") - sb.WriteString(" toValue: \"1000100\"\n") - sb.WriteString(" requiresBridge: true\n") - sb.WriteString(" bridgeTable: canonical_beacon_block\n") - sb.WriteString(" confidence: 0.8\n") - sb.WriteString(" reasoning: \"Need to convert time range to block numbers via beacon chain\"\n") + sb.WriteString(" confidence: 0.9\n") + sb.WriteString(" reasoning: \"Filters from intermediate SQL\"\n") + sb.WriteString(" - model: ethseer_validator_entity\n") + sb.WriteString(" rangeColumn: \"\"\n") + sb.WriteString(" columnType: none\n") + sb.WriteString(" fromValue: \"\"\n") + sb.WriteString(" toValue: \"\"\n") + sb.WriteString(" filterSql: \"\"\n") + sb.WriteString(" correlationFilter: \"validator_index GLOBAL IN (SELECT DISTINCT attesting_validator_index FROM default.beacon_api_eth_v1_events_attestation WHERE slot_start_date_time >= toDateTime('2025-01-01 00:00:00') AND slot_start_date_time <= toDateTime('2025-01-01 00:05:00') AND meta_network_name = 'mainnet')\"\n") + sb.WriteString(" optional: true\n") + sb.WriteString(" requiresBridge: false\n") + sb.WriteString(" confidence: 0.9\n") + sb.WriteString(" reasoning: \"Entity table filtered by correlation - only validators appearing in attestation data\"\n") sb.WriteString("overallConfidence: 0.85\n") - sb.WriteString("summary: \"Using time-based primary range with block correlation via canonical_beacon_block\"\n") + sb.WriteString("summary: \"Time-based primary range with filters from dependencies\"\n") sb.WriteString("warnings:\n") - sb.WriteString(" - \"Some execution tables may have gaps where no beacon block was produced\"\n") + sb.WriteString(" - \"Filters applied to ensure usable seed data\"\n") sb.WriteString("```\n\n") sb.WriteString("IMPORTANT:\n") @@ -352,6 +421,8 @@ func (c *ClaudeDiscoveryClient) buildDiscoveryPrompt(input DiscoveryInput) strin sb.WriteString("- Pick a recent time window (last hour or so) within the intersection of all available ranges\n") sb.WriteString("- For block_number tables, estimate block numbers that correspond to the chosen time window\n") sb.WriteString("- Include ALL external models in the strategies list\n") + sb.WriteString("- **ANALYZE ALL WHERE CLAUSES** in transformation and intermediate SQL - missing filters will cause empty test output!\n") + sb.WriteString("- Include `filterSql` for each model (empty string if no additional filters needed)\n") sb.WriteString("- Output ONLY the YAML, no explanations before or after\n") return sb.String() @@ -362,23 +433,112 @@ func (c *ClaudeDiscoveryClient) parseDiscoveryResponse(response string) (*Discov // Extract YAML from response yamlContent := extractYAMLFromResponse(response) if yamlContent == "" { + // Log the raw response for debugging + c.log.WithField("response_preview", truncateString(response, 500)).Error("no valid YAML found in Claude response") + return nil, fmt.Errorf("no valid YAML found in Claude response") } + // Normalize field names (Claude may output snake_case instead of camelCase) + yamlContent = normalizeDiscoveryYAMLFields(yamlContent) + + c.log.WithField("yaml_preview", truncateString(yamlContent, 300)).Debug("extracted YAML content") + var result DiscoveryResult if err := yaml.Unmarshal([]byte(yamlContent), &result); err != nil { + c.log.WithFields(logrus.Fields{ + "error": err, + "yaml_content": truncateString(yamlContent, 500), + }).Error("failed to parse discovery YAML") + return nil, fmt.Errorf("failed to parse discovery YAML: %w", err) } // Validate result if err := c.validateDiscoveryResult(&result); err != nil { + c.log.WithFields(logrus.Fields{ + "error": err, + "yaml_content": truncateString(yamlContent, 500), + "parsed": result, + }).Error("invalid discovery result") + return nil, fmt.Errorf("invalid discovery result: %w", err) } return &result, nil } +// normalizeDiscoveryYAMLFields converts common snake_case field names to camelCase +// and fixes common YAML formatting issues in Claude's output. +func normalizeDiscoveryYAMLFields(yamlContent string) string { + // Map of snake_case to camelCase field names + replacements := map[string]string{ + "primary_range_type:": "primaryRangeType:", + "primary_range_column:": "primaryRangeColumn:", + "from_value:": "fromValue:", + "to_value:": "toValue:", + "range_column:": "rangeColumn:", + "column_type:": "columnType:", + "filter_sql:": "filterSql:", + "correlation_filter:": "correlationFilter:", + "requires_bridge:": "requiresBridge:", + "bridge_table:": "bridgeTable:", + "bridge_join_sql:": "bridgeJoinSql:", + "overall_confidence:": "overallConfidence:", + } + + result := yamlContent + for snake, camel := range replacements { + result = strings.ReplaceAll(result, snake, camel) + } + + // Fix unquoted datetime values (e.g., "fromValue: 2025-01-01 00:00:00" -> "fromValue: \"2025-01-01 00:00:00\"") + result = fixUnquotedDatetimes(result) + + return result +} + +// fixUnquotedDatetimes finds unquoted datetime values and adds quotes. +// Matches patterns like "fromValue: 2025-01-01 00:00:00" where the datetime is not quoted. +func fixUnquotedDatetimes(yamlContent string) string { + // Pattern: key followed by colon, space, then YYYY-MM-DD HH:MM:SS (not already quoted) + // We need to be careful not to double-quote already quoted values + lines := strings.Split(yamlContent, "\n") + result := make([]string, 0, len(lines)) + + datetimePattern := regexp.MustCompile(`^(\s*\w+:\s*)(\d{4}-\d{2}-\d{2}\s+\d{2}:\d{2}:\d{2})(\s*#.*)?$`) + + for _, line := range lines { + matches := datetimePattern.FindStringSubmatch(line) + if matches != nil { + // Found unquoted datetime - add quotes + prefix := matches[1] + datetime := matches[2] + + suffix := "" + if len(matches) > 3 { + suffix = matches[3] + } + + line = prefix + "\"" + datetime + "\"" + suffix + } + + result = append(result, line) + } + + return strings.Join(result, "\n") +} + +// truncateString truncates a string to maxLen characters, adding "..." if truncated. +func truncateString(s string, maxLen int) string { + if len(s) <= maxLen { + return s + } + + return s[:maxLen] + "..." +} + // validateDiscoveryResult checks if the AI result is valid and complete. func (c *ClaudeDiscoveryClient) validateDiscoveryResult(result *DiscoveryResult) error { if result.PrimaryRangeColumn == "" { @@ -398,18 +558,61 @@ func (c *ClaudeDiscoveryClient) validateDiscoveryResult(result *DiscoveryResult) return fmt.Errorf("strategy %d: model is required", i) } - if s.RangeColumn == "" { - return fmt.Errorf("strategy %d (%s): range_column is required", i, s.Model) + // Tables can be filtered by: + // 1. Range column (rangeColumn + fromValue/toValue) + // 2. Correlation filter (subquery) + // 3. None type (dimension table - accepts all or filtered by filterSQL) + hasRangeFilter := s.RangeColumn != "" && s.FromValue != "" && s.ToValue != "" + hasCorrelationFilter := s.CorrelationFilter != "" + isNoneType := s.ColumnType == RangeColumnTypeNone + + // Must have at least one filtering mechanism + if !hasRangeFilter && !hasCorrelationFilter && !isNoneType { + return fmt.Errorf("strategy %d (%s): requires range filter (rangeColumn+from/to), correlationFilter, or columnType: none", i, s.Model) } - if s.FromValue == "" || s.ToValue == "" { - return fmt.Errorf("strategy %d (%s): from_value and to_value are required", i, s.Model) + // If range column is specified, from/to are required + if s.RangeColumn != "" && !hasCorrelationFilter && (s.FromValue == "" || s.ToValue == "") { + return fmt.Errorf("strategy %d (%s): from_value and to_value are required when range_column is set without correlationFilter", i, s.Model) } } return nil } +// findTimeColumnInSchema looks for a time-based column in the schema. +// Prefers columns with "date_time" in the name, falls back to any DateTime column. +func findTimeColumnInSchema(columns []ColumnInfo) string { + // First, look for columns with "date_time" in the name (most common pattern) + for _, col := range columns { + colLower := strings.ToLower(col.Name) + if strings.Contains(colLower, "date_time") { + return col.Name + } + } + + // Fall back to any DateTime column + for _, col := range columns { + typeLower := strings.ToLower(col.Type) + if strings.Contains(typeLower, "datetime") { + return col.Name + } + } + + return "" +} + +// contains checks if a string slice contains a value. +func contains(slice []string, value string) bool { + for _, v := range slice { + if v == value { + return true + } + } + + return false +} + // ClassifyRangeColumn determines the semantic type of a range column based on its name and schema type. func ClassifyRangeColumn(column string, schema []ColumnInfo) RangeColumnType { colLower := strings.ToLower(column) @@ -509,22 +712,45 @@ func (g *Generator) QueryTableSample( return jsonResp.Data, nil } -// FallbackRangeDiscovery provides heuristic-based range discovery when Claude is unavailable. -func FallbackRangeDiscovery( - ctx context.Context, - gen *Generator, +// categorizeModelsByType groups models into time, block, entity, and unknown categories. +// Uses interval types from frontmatter if available, falls back to column-based detection. +func categorizeModelsByType( models []string, - network string, + intervalTypes map[string]IntervalType, rangeInfos map[string]*RangeColumnInfo, - duration string, -) (*DiscoveryResult, error) { - // Group models by range column type - timeModels := make([]string, 0) - blockModels := make([]string, 0) +) (timeModels, blockModels, entityModels, unknownModels []string) { + timeModels = make([]string, 0) + blockModels = make([]string, 0) + entityModels = make([]string, 0) + unknownModels = make([]string, 0) for _, model := range models { + // First check frontmatter interval type (most accurate) + if intervalTypes != nil { + if intervalType, ok := intervalTypes[model]; ok { + switch intervalType { + case IntervalTypeEntity: + // Entity models need special handling - they have time columns but indexed by entity + entityModels = append(entityModels, model) + + continue + case IntervalTypeBlock: + blockModels = append(blockModels, model) + + continue + case IntervalTypeSlot: + timeModels = append(timeModels, model) + + continue + } + } + } + + // Fallback: use range column detection info, ok := rangeInfos[model] if !ok { + unknownModels = append(unknownModels, model) + continue } @@ -535,38 +761,112 @@ func FallbackRangeDiscovery( timeModels = append(timeModels, model) case strings.Contains(colLower, "block"): blockModels = append(blockModels, model) + default: + unknownModels = append(unknownModels, model) } } - // Determine primary range type based on majority - var primaryType RangeColumnType + return timeModels, blockModels, entityModels, unknownModels +} - var primaryColumn string +// FallbackRangeDiscovery provides heuristic-based range discovery when Claude is unavailable. +func FallbackRangeDiscovery( + ctx context.Context, + gen *Generator, + models []string, + network string, + rangeInfos map[string]*RangeColumnInfo, + duration string, + xatuCBTPath string, +) (*DiscoveryResult, error) { + // Get interval types from model frontmatter for accurate categorization + intervalTypes, err := GetExternalModelIntervalTypes(models, xatuCBTPath) + if err != nil { + // Log warning but continue with column-based detection as fallback + gen.log.WithError(err).Warn("failed to get interval types from frontmatter, using column-based detection") - if len(timeModels) >= len(blockModels) { - primaryType = RangeColumnTypeTime - primaryColumn = DefaultRangeColumn - } else { - primaryType = RangeColumnTypeBlock - primaryColumn = "block_number" + intervalTypes = nil } - // Query ranges for all models + // Group models by interval type + timeModels, blockModels, entityModels, unknownModels := categorizeModelsByType(models, intervalTypes, rangeInfos) + + // Query ranges for models var latestMin, earliestMax time.Time strategies := make([]TableRangeStrategy, 0, len(models)) for _, model := range models { info := rangeInfos[model] - rangeCol := DefaultRangeColumn - if info != nil { - rangeCol = info.RangeColumn + // Check model category + isEntity := contains(entityModels, model) + isUnknown := contains(unknownModels, model) + + var rangeCol string + + var colType RangeColumnType + + // For entity models, find a time column in the schema + if isEntity { + // Query schema to find a time column + columns, schemaErr := gen.DescribeTable(ctx, model) + if schemaErr != nil { + gen.log.WithError(schemaErr).WithField("model", model).Warn("failed to describe entity table") + + strategies = append(strategies, TableRangeStrategy{ + Model: model, + ColumnType: RangeColumnTypeTime, + Confidence: 0.3, + Reasoning: fmt.Sprintf("Entity model (schema query failed: %v)", schemaErr), + }) + + continue + } + + timeCol := findTimeColumnInSchema(columns) + if timeCol == "" { + gen.log.WithField("model", model).Warn("no time column found for entity model") + + strategies = append(strategies, TableRangeStrategy{ + Model: model, + ColumnType: RangeColumnTypeTime, + Confidence: 0.3, + Reasoning: "Entity model - no time column found in schema", + }) + + continue + } + + rangeCol = timeCol + colType = RangeColumnTypeTime + } else if isUnknown { + // Unknown models - try default range column + rangeCol = DefaultRangeColumn + colType = RangeColumnTypeTime + } else { + // Time or block models - use detected range column + rangeCol = DefaultRangeColumn + if info != nil { + rangeCol = info.RangeColumn + } + + colType = ClassifyRangeColumn(rangeCol, nil) } - modelRange, err := gen.QueryModelRange(ctx, model, network, rangeCol) - if err != nil { - return nil, fmt.Errorf("failed to query range for %s: %w", model, err) + modelRange, queryErr := gen.QueryModelRange(ctx, model, network, rangeCol) + if queryErr != nil { + gen.log.WithError(queryErr).WithField("model", model).Warn("range query failed") + + strategies = append(strategies, TableRangeStrategy{ + Model: model, + RangeColumn: rangeCol, + ColumnType: colType, + Confidence: 0.3, + Reasoning: fmt.Sprintf("Range query failed: %v", queryErr), + }) + + continue } if latestMin.IsZero() || modelRange.Min.After(latestMin) { @@ -577,43 +877,71 @@ func FallbackRangeDiscovery( earliestMax = modelRange.Max } - colType := ClassifyRangeColumn(rangeCol, nil) + reasoning := "Heuristic-based detection (Claude unavailable)" + if isEntity { + reasoning = fmt.Sprintf("Entity model - using %s for time filtering", rangeCol) + } strategies = append(strategies, TableRangeStrategy{ Model: model, RangeColumn: rangeCol, ColumnType: colType, - Confidence: 0.7, // Lower confidence for heuristic - Reasoning: "Heuristic-based detection (Claude unavailable)", + Confidence: 0.7, + Reasoning: reasoning, }) } - // Check for intersection - if latestMin.After(earliestMax) { - return nil, fmt.Errorf("no intersecting range found across all models") - } + // Handle case where no models have valid ranges + hasRanges := !latestMin.IsZero() && !earliestMax.IsZero() - // Parse duration string (e.g., "5m", "10m", "1h") - rangeDuration, parseErr := time.ParseDuration(duration) - if parseErr != nil { - rangeDuration = 5 * time.Minute // Default to 5 minutes if parsing fails - } + var fromValue, toValue string - // Use the last N minutes/hours of available data - effectiveMax := earliestMax.Add(-1 * time.Minute) // Account for ingestion lag - effectiveMin := effectiveMax.Add(-rangeDuration) + var primaryType RangeColumnType - if effectiveMin.Before(latestMin) { - effectiveMin = latestMin - } + var primaryColumn string - fromValue := effectiveMin.Format("2006-01-02 15:04:05") - toValue := effectiveMax.Format("2006-01-02 15:04:05") + if hasRanges { + // Check for intersection + if latestMin.After(earliestMax) { + return nil, fmt.Errorf("no intersecting range found across all models") + } + + // Parse duration string (e.g., "5m", "10m", "1h") + rangeDuration, parseErr := time.ParseDuration(duration) + if parseErr != nil { + rangeDuration = 5 * time.Minute // Default to 5 minutes if parsing fails + } - // Update strategies with range values - for i := range strategies { - strategies[i].FromValue = fromValue - strategies[i].ToValue = toValue + // Use the last N minutes/hours of available data + effectiveMax := earliestMax.Add(-1 * time.Minute) // Account for ingestion lag + effectiveMin := effectiveMax.Add(-rangeDuration) + + if effectiveMin.Before(latestMin) { + effectiveMin = latestMin + } + + fromValue = effectiveMin.Format("2006-01-02 15:04:05") + toValue = effectiveMax.Format("2006-01-02 15:04:05") + + // Determine primary range type based on majority + if len(timeModels)+len(entityModels) >= len(blockModels) { + primaryType = RangeColumnTypeTime + primaryColumn = DefaultRangeColumn + } else { + primaryType = RangeColumnTypeBlock + primaryColumn = "block_number" + } + + // Update strategies with range values (skip strategies without valid range column) + for i := range strategies { + if strategies[i].RangeColumn != "" { + strategies[i].FromValue = fromValue + strategies[i].ToValue = toValue + } + } + } else { + // No valid ranges found - this is an error condition + return nil, fmt.Errorf("no valid range columns found for any model") } warnings := make([]string, 0) @@ -683,7 +1011,7 @@ func (g *Generator) ValidateStrategyHasData( minRowModel := "" for _, strategy := range result.Strategies { - count, err := g.QueryRowCount(ctx, strategy.Model, network, strategy.RangeColumn, strategy.FromValue, strategy.ToValue) + count, err := g.QueryRowCount(ctx, strategy.Model, network, strategy.RangeColumn, strategy.FromValue, strategy.ToValue, strategy.FilterSQL, strategy.CorrelationFilter) modelCount := ModelDataCount{ Model: strategy.Model, @@ -725,6 +1053,7 @@ func (g *Generator) ValidateStrategyHasData( } // QueryRowCount queries the number of rows in a model for a given range. +// For dimension tables (empty rangeColumn), it counts all rows for the network. func (g *Generator) QueryRowCount( ctx context.Context, model string, @@ -732,30 +1061,53 @@ func (g *Generator) QueryRowCount( rangeColumn string, fromValue string, toValue string, + filterSQL string, + correlationFilter string, ) (int64, error) { - // Determine if this is a numeric or time-based column - isNumeric := !strings.Contains(strings.ToLower(rangeColumn), "date") && - !strings.Contains(strings.ToLower(rangeColumn), "time") + // Build additional filter clause + filterClause := "" + if filterSQL != "" { + filterClause = fmt.Sprintf("\n AND %s", filterSQL) + } + + if correlationFilter != "" { + filterClause += fmt.Sprintf("\n AND %s", correlationFilter) + } var query string - if isNumeric { + + // Handle dimension tables (no range column) + if rangeColumn == "" { query = fmt.Sprintf(` + SELECT COUNT(*) as cnt + FROM default.%s + WHERE meta_network_name = '%s'%s + FORMAT JSON + `, model, network, filterClause) + } else { + // Determine if this is a numeric or time-based column + isNumeric := !strings.Contains(strings.ToLower(rangeColumn), "date") && + !strings.Contains(strings.ToLower(rangeColumn), "time") + + if isNumeric { + query = fmt.Sprintf(` SELECT COUNT(*) as cnt FROM default.%s WHERE meta_network_name = '%s' AND %s >= %s - AND %s <= %s + AND %s <= %s%s FORMAT JSON - `, model, network, rangeColumn, fromValue, rangeColumn, toValue) - } else { - query = fmt.Sprintf(` + `, model, network, rangeColumn, fromValue, rangeColumn, toValue, filterClause) + } else { + query = fmt.Sprintf(` SELECT COUNT(*) as cnt FROM default.%s WHERE meta_network_name = '%s' AND %s >= toDateTime('%s') - AND %s <= toDateTime('%s') + AND %s <= toDateTime('%s')%s FORMAT JSON - `, model, network, rangeColumn, fromValue, rangeColumn, toValue) + `, model, network, rangeColumn, fromValue, rangeColumn, toValue, filterClause) + } } g.log.WithFields(logrus.Fields{ diff --git a/pkg/seeddata/generator.go b/pkg/seeddata/generator.go index 4be0f21..0b57742 100644 --- a/pkg/seeddata/generator.go +++ b/pkg/seeddata/generator.go @@ -37,17 +37,19 @@ func NewGenerator(log logrus.FieldLogger, cfg *config.LabConfig) *Generator { // GenerateOptions contains options for generating seed data. type GenerateOptions struct { - Model string // Table name (e.g., "beacon_api_eth_v1_events_block") - Network string // Network name (e.g., "mainnet", "sepolia") - Spec string // Fork spec (e.g., "pectra", "fusaka") - RangeColumn string // Column to filter on (e.g., "slot", "epoch") - From string // Range start value - To string // Range end value - Filters []Filter // Additional filters - Limit int // Max rows (0 = unlimited) - OutputPath string // Output file path - SanitizeIPs bool // Enable IP address sanitization - Salt string // Salt for IP sanitization (shared across batch for consistency) + Model string // Table name (e.g., "beacon_api_eth_v1_events_block") + Network string // Network name (e.g., "mainnet", "sepolia") + Spec string // Fork spec (e.g., "pectra", "fusaka") + RangeColumn string // Column to filter on (e.g., "slot", "epoch") + From string // Range start value + To string // Range end value + Filters []Filter // Additional filters + FilterSQL string // Raw SQL fragment for additional WHERE conditions (from AI discovery) + CorrelationFilter string // Subquery filter for dimension tables (e.g., "validator_index IN (SELECT ...)") + Limit int // Max rows (0 = unlimited) + OutputPath string // Output file path + SanitizeIPs bool // Enable IP address sanitization + Salt string // Salt for IP sanitization (shared across batch for consistency) // sanitizedColumns is an internal field set by Generate() when SanitizeIPs is true. // It contains the pre-computed column list with IP sanitization expressions. @@ -179,7 +181,7 @@ func (g *Generator) buildQuery(opts GenerateOptions) string { } } - // Add additional filters + // Add additional filters (structured) for _, filter := range opts.Filters { sb.WriteString("\n AND ") sb.WriteString(filter.Column) @@ -189,6 +191,18 @@ func (g *Generator) buildQuery(opts GenerateOptions) string { sb.WriteString(formatSQLValue(filter.Value)) } + // Add raw SQL filter if specified (from AI discovery) + if opts.FilterSQL != "" { + sb.WriteString("\n AND ") + sb.WriteString(opts.FilterSQL) + } + + // Add correlation filter if specified (subquery for dimension tables) + if opts.CorrelationFilter != "" { + sb.WriteString("\n AND ") + sb.WriteString(opts.CorrelationFilter) + } + // Add limit if specified if opts.Limit > 0 { sb.WriteString(fmt.Sprintf("\nLIMIT %d", opts.Limit)) diff --git a/pkg/seeddata/s3.go b/pkg/seeddata/s3.go index 91a5d25..3018cad 100644 --- a/pkg/seeddata/s3.go +++ b/pkg/seeddata/s3.go @@ -141,12 +141,15 @@ func (u *S3Uploader) Upload(ctx context.Context, opts UploadOptions) (*UploadRes }).Debug("uploading file to S3") // Upload to S3 with explicit content length + // Cache-Control: no-cache ensures CDN/browsers always revalidate with origin + // R2 will use ETag for efficient conditional requests (304 Not Modified) _, err = u.client.PutObject(ctx, &s3.PutObjectInput{ Bucket: aws.String(u.bucket), Key: aws.String(key), Body: file, ContentType: aws.String("application/octet-stream"), ContentLength: aws.Int64(fileSize), + CacheControl: aws.String("no-cache"), }) if err != nil { return nil, fmt.Errorf("failed to upload to S3: %w", err)