Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
103 changes: 96 additions & 7 deletions cel2sql.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,21 +18,50 @@ import (
// Implementations based on `google/cel-go`'s unparser
// https://github.com/google/cel-go/blob/master/parser/unparser.go

// Convert converts a CEL AST to a PostgreSQL SQL WHERE clause condition.
func Convert(ast *cel.Ast) (string, error) {
return ConvertWithSchemas(ast, nil)
// ConvertOption is a functional option for configuring the Convert function.
type ConvertOption func(*convertOptions)

// convertOptions holds configuration options for the Convert function.
type convertOptions struct {
schemas map[string]pg.Schema
}

// WithSchemas provides schema information for proper JSON/JSONB field handling.
// This option is required for correct SQL generation when using JSON/JSONB fields.
//
// Example:
//
// schemas := provider.GetSchemas()
// sql, err := cel2sql.Convert(ast, cel2sql.WithSchemas(schemas))
func WithSchemas(schemas map[string]pg.Schema) ConvertOption {
return func(o *convertOptions) {
o.schemas = schemas
}
}

// ConvertWithSchemas converts a CEL AST to a PostgreSQL SQL WHERE clause condition,
// using schema information to properly handle JSON/JSONB fields.
func ConvertWithSchemas(ast *cel.Ast, schemas map[string]pg.Schema) (string, error) {
// Convert converts a CEL AST to a PostgreSQL SQL WHERE clause condition.
// Options can be provided to configure the conversion behavior.
//
// Example without options:
//
// sql, err := cel2sql.Convert(ast)
//
// Example with schema information for JSON/JSONB support:
//
// sql, err := cel2sql.Convert(ast, cel2sql.WithSchemas(schemas))
func Convert(ast *cel.Ast, opts ...ConvertOption) (string, error) {
options := &convertOptions{}
for _, opt := range opts {
opt(options)
}

checkedExpr, err := cel.AstToCheckedExpr(ast)
if err != nil {
return "", err
}
un := &converter{
typeMap: checkedExpr.TypeMap,
schemas: schemas,
schemas: options.schemas,
}
if err := un.visit(checkedExpr.Expr); err != nil {
return "", err
Expand Down Expand Up @@ -107,6 +136,66 @@ func (con *converter) getTableAndFieldFromSelectChain(expr *exprpb.Expr) (string
return "", "", false
}

// isFieldJSONB checks if a field in a table is specifically JSONB (vs JSON) using schema information
func (con *converter) isFieldJSONB(tableName, fieldName string) bool {
if con.schemas == nil {
return false
}

schema, ok := con.schemas[tableName]
if !ok {
return false
}

for _, field := range schema {
if field.Name == fieldName {
return field.IsJSONB
}
}

return false
}

// isFieldArray checks if a field in a table is an array using schema information
func (con *converter) isFieldArray(tableName, fieldName string) bool {
if con.schemas == nil {
return false
}

schema, ok := con.schemas[tableName]
if !ok {
return false
}

for _, field := range schema {
if field.Name == fieldName {
return field.Repeated
}
}

return false
}

// getFieldElementType returns the element type of an array field using schema information
func (con *converter) getFieldElementType(tableName, fieldName string) string {
if con.schemas == nil {
return ""
}

schema, ok := con.schemas[tableName]
if !ok {
return ""
}

for _, field := range schema {
if field.Name == fieldName && field.Repeated {
return field.ElementType
}
}

return ""
}

func (con *converter) visitCall(expr *exprpb.Expr) error {
c := expr.GetCallExpr()
fun := c.GetFunction()
Expand Down
3 changes: 2 additions & 1 deletion examples/load_table_schema/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,8 @@ func exampleWithPredefinedSchema() {
continue
}

sqlCondition, err := cel2sql.Convert(ast)
// Convert to SQL with schema information for JSON field detection
sqlCondition, err := cel2sql.Convert(ast, cel2sql.WithSchemas(provider.GetSchemas()))
if err != nil {
log.Printf("Error converting %s: %v", expr, err)
continue
Expand Down
72 changes: 36 additions & 36 deletions field_validation_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,9 @@ func TestFieldNameValidation_Integration(t *testing.T) {
})

tests := []struct {
name string
celExpr string
expectError bool
name string
celExpr string
expectError bool
errorContains string
}{
// Valid field names should work
Expand All @@ -43,33 +43,33 @@ func TestFieldNameValidation_Integration(t *testing.T) {

// Reserved keywords should be rejected
{
name: "reserved keyword: select",
celExpr: `obj.select == "test"`,
expectError: true,
name: "reserved keyword: select",
celExpr: `obj.select == "test"`,
expectError: true,
errorContains: "reserved SQL keyword",
},
{
name: "reserved keyword: where",
celExpr: `obj.where == "test"`,
expectError: true,
name: "reserved keyword: where",
celExpr: `obj.where == "test"`,
expectError: true,
errorContains: "reserved SQL keyword",
},
{
name: "reserved keyword: from",
celExpr: `obj.from == "test"`,
expectError: true,
name: "reserved keyword: from",
celExpr: `obj.from == "test"`,
expectError: true,
errorContains: "reserved SQL keyword",
},
{
name: "reserved keyword: union",
celExpr: `obj.union == "test"`,
expectError: true,
name: "reserved keyword: union",
celExpr: `obj.union == "test"`,
expectError: true,
errorContains: "reserved SQL keyword",
},
{
name: "reserved keyword: drop",
celExpr: `obj.drop == "test"`,
expectError: true,
name: "reserved keyword: drop",
celExpr: `obj.drop == "test"`,
expectError: true,
errorContains: "reserved SQL keyword",
},
}
Expand Down Expand Up @@ -111,10 +111,10 @@ func TestFieldNameValidation_Integration(t *testing.T) {
// TestFieldNameValidation_Identifiers tests identifier validation
func TestFieldNameValidation_Identifiers(t *testing.T) {
tests := []struct {
name string
varName string
celExpr string
expectError bool
name string
varName string
celExpr string
expectError bool
errorContains string
}{
{
Expand All @@ -130,17 +130,17 @@ func TestFieldNameValidation_Identifiers(t *testing.T) {
expectError: false,
},
{
name: "reserved keyword identifier",
varName: "select",
celExpr: `select == "test"`,
expectError: true,
name: "reserved keyword identifier",
varName: "select",
celExpr: `select == "test"`,
expectError: true,
errorContains: "reserved SQL keyword",
},
{
name: "reserved keyword: table",
varName: "table",
celExpr: `table == 5`,
expectError: true,
name: "reserved keyword: table",
varName: "table",
celExpr: `table == 5`,
expectError: true,
errorContains: "reserved SQL keyword",
},
}
Expand Down Expand Up @@ -218,9 +218,9 @@ func TestFieldNameValidation_PreventsSQLInjection(t *testing.T) {
// Comprehensive validation testing is in utils_test.go

maliciousPatterns := []struct {
name string
celExpr string
reason string
name string
celExpr string
reason string
}{
{
name: "cannot use semicolon in field name",
Expand Down Expand Up @@ -254,9 +254,9 @@ func TestFieldNameValidation_PreventsSQLInjection(t *testing.T) {
// TestFieldNameValidation_EdgeCases tests edge cases
func TestFieldNameValidation_EdgeCases(t *testing.T) {
tests := []struct {
name string
fieldName string
shouldPass bool
name string
fieldName string
shouldPass bool
}{
{
name: "single character",
Expand Down
Loading
Loading