diff --git a/go.mod b/go.mod index b5c7430a6f..f10d1dcea9 100644 --- a/go.mod +++ b/go.mod @@ -6,11 +6,11 @@ require ( github.com/PuerkitoBio/goquery v1.8.1 github.com/cockroachdb/apd/v2 v2.0.3-0.20200518165714-d020e156310a github.com/cockroachdb/errors v1.7.5 - github.com/dolthub/dolt/go v0.40.5-0.20250415001434-2fdac4b164b9 + github.com/dolthub/dolt/go v0.40.5-0.20250417234826-cbc8ae986979 github.com/dolthub/dolt/go/gen/proto/dolt/services/eventsapi v0.0.0-20241119094239-f4e529af734d github.com/dolthub/flatbuffers/v23 v23.3.3-dh.2 github.com/dolthub/go-icu-regex v0.0.0-20250327004329-6799764f2dad - github.com/dolthub/go-mysql-server v0.19.1-0.20250414233448-814abccc8b6d + github.com/dolthub/go-mysql-server v0.19.1-0.20250417231730-fcd059390dd6 github.com/dolthub/sqllogictest/go v0.0.0-20240618184124-ca47f9354216 github.com/dolthub/vitess v0.0.0-20250414165810-f0031a6472b7 github.com/fatih/color v1.13.0 diff --git a/go.sum b/go.sum index f641bdd95d..9dec0de1d9 100644 --- a/go.sum +++ b/go.sum @@ -256,8 +256,8 @@ github.com/docker/go-connections v0.4.0/go.mod h1:Gbd7IOopHjR8Iph03tsViu4nIes5Xh github.com/docker/go-units v0.4.0/go.mod h1:fgPhTUdO+D/Jk86RDLlptpiXQzgHJF7gydDDbaIK4Dk= github.com/dolthub/aws-sdk-go-ini-parser v0.0.0-20250305001723-2821c37f6c12 h1:IdqX7J8vi/Kn3T3Ee0VzqnLqwFmgA2hr8WZETPcQjfM= github.com/dolthub/aws-sdk-go-ini-parser v0.0.0-20250305001723-2821c37f6c12/go.mod h1:rN7X8BHwkjPcfMQQ2QTAq/xM3leUSGLfb+1Js7Y6TVo= -github.com/dolthub/dolt/go v0.40.5-0.20250415001434-2fdac4b164b9 h1:ZWDHwpVJTO8WQPLqhe63E1oXPPq0SVQVD6OrlO2rQ4Q= -github.com/dolthub/dolt/go v0.40.5-0.20250415001434-2fdac4b164b9/go.mod h1:ugNKxGCsAzevsXPHmBsDAnZ+yU0ty4BLNHQ9pwPKh1A= +github.com/dolthub/dolt/go v0.40.5-0.20250417234826-cbc8ae986979 h1:AtvujGDcRYBTSPL0G2IWZdoXOh3S1COdgTyX/o85RPI= +github.com/dolthub/dolt/go v0.40.5-0.20250417234826-cbc8ae986979/go.mod h1:O4i55d+mBJhvlbUvpz7wZ+YaNWGb41t4KBL/6q+exVo= github.com/dolthub/dolt/go/gen/proto/dolt/services/eventsapi v0.0.0-20241119094239-f4e529af734d h1:gO9+wrmNHXukPNCO1tpfCcXIdMlW/qppbUStfLvqz/U= github.com/dolthub/dolt/go/gen/proto/dolt/services/eventsapi v0.0.0-20241119094239-f4e529af734d/go.mod h1:L5RDYZbC9BBWmoU2+TjTekeqqhFXX5EqH9ln00O0stY= github.com/dolthub/flatbuffers/v23 v23.3.3-dh.2 h1:u3PMzfF8RkKd3lB9pZ2bfn0qEG+1Gms9599cr0REMww= @@ -266,8 +266,8 @@ github.com/dolthub/fslock v0.0.3 h1:iLMpUIvJKMKm92+N1fmHVdxJP5NdyDK5bK7z7Ba2s2U= github.com/dolthub/fslock v0.0.3/go.mod h1:QWql+P17oAAMLnL4HGB5tiovtDuAjdDTPbuqx7bYfa0= github.com/dolthub/go-icu-regex v0.0.0-20250327004329-6799764f2dad h1:66ZPawHszNu37VPQckdhX1BPPVzREsGgNxQeefnlm3g= github.com/dolthub/go-icu-regex v0.0.0-20250327004329-6799764f2dad/go.mod h1:ylU4XjUpsMcvl/BKeRRMXSH7e7WBrPXdSLvnRJYrxEA= -github.com/dolthub/go-mysql-server v0.19.1-0.20250414233448-814abccc8b6d h1:wDCmc0OCBo1+Bc8jWTdv1Ps8YUg9FcJxdOHlc9ZuRCY= -github.com/dolthub/go-mysql-server v0.19.1-0.20250414233448-814abccc8b6d/go.mod h1:bOB8AJAqzKtYw/xDA9UININeFU1Il4eYEqSP3lZOqIA= +github.com/dolthub/go-mysql-server v0.19.1-0.20250417231730-fcd059390dd6 h1:mqr/9yLWqM8eMxxQ9/ruRpdJmyC6hR0K9AMiafqnLiE= +github.com/dolthub/go-mysql-server v0.19.1-0.20250417231730-fcd059390dd6/go.mod h1:bOB8AJAqzKtYw/xDA9UININeFU1Il4eYEqSP3lZOqIA= github.com/dolthub/gozstd v0.0.0-20240423170813-23a2903bca63 h1:OAsXLAPL4du6tfbBgK0xXHZkOlos63RdKYS3Sgw/dfI= github.com/dolthub/gozstd v0.0.0-20240423170813-23a2903bca63/go.mod h1:lV7lUeuDhH5thVGDCKXbatwKy2KW80L4rMT46n+Y2/Q= github.com/dolthub/ishell v0.0.0-20240701202509-2b217167d718 h1:lT7hE5k+0nkBdj/1UOSFwjWpNxf+LCApbRHgnCA17XE= diff --git a/server/analyzer/type_sanitizer.go b/server/analyzer/type_sanitizer.go index 88f2034b88..59acdef14f 100644 --- a/server/analyzer/type_sanitizer.go +++ b/server/analyzer/type_sanitizer.go @@ -86,12 +86,13 @@ func TypeSanitizer(ctx *sql.Context, a *analyzer.Analyzer, node sql.Node, scope // typeSanitizerLiterals handles literal expressions for TypeSanitizer. func typeSanitizerLiterals(ctx context.Context, gmsLiteral *expression.Literal) (sql.Expression, transform.TreeIdentity, error) { // GMS may resolve Doltgres literals and then stick them in GMS literals, so we have to account for that here + // TODO: is this necessary any more? if doltgresType, ok := gmsLiteral.Type().(*pgtypes.DoltgresType); ok { - return pgexprs.NewUnsafeLiteral(gmsLiteral.Value(), doltgresType), transform.NewTree, nil + return pgexprs.NewUnsafeLiteral(gmsLiteral.LiteralValue(), doltgresType), transform.NewTree, nil } switch gmsLiteral.Type().Type() { case query.Type_INT8, query.Type_INT16, query.Type_YEAR: - newVal, _, err := types.Int16.Convert(ctx, gmsLiteral.Value()) + newVal, _, err := types.Int16.Convert(ctx, gmsLiteral.LiteralValue()) if err != nil { return nil, transform.NewTree, err } @@ -100,7 +101,7 @@ func typeSanitizerLiterals(ctx context.Context, gmsLiteral *expression.Literal) } return pgexprs.NewRawLiteralInt16(newVal.(int16)), transform.NewTree, nil case query.Type_INT24, query.Type_INT32: - newVal, _, err := types.Int32.Convert(ctx, gmsLiteral.Value()) + newVal, _, err := types.Int32.Convert(ctx, gmsLiteral.LiteralValue()) if err != nil { return nil, transform.NewTree, err } @@ -109,7 +110,7 @@ func typeSanitizerLiterals(ctx context.Context, gmsLiteral *expression.Literal) } return pgexprs.NewRawLiteralInt32(newVal.(int32)), transform.NewTree, nil case query.Type_INT64, query.Type_ENUM: - newVal, _, err := types.Int64.Convert(ctx, gmsLiteral.Value()) + newVal, _, err := types.Int64.Convert(ctx, gmsLiteral.LiteralValue()) if err != nil { return nil, transform.NewTree, err } @@ -118,7 +119,7 @@ func typeSanitizerLiterals(ctx context.Context, gmsLiteral *expression.Literal) } return pgexprs.NewRawLiteralInt64(newVal.(int64)), transform.NewTree, nil case query.Type_UINT8, query.Type_UINT16, query.Type_UINT24, query.Type_UINT32: - newVal, _, err := types.Uint32.Convert(ctx, gmsLiteral.Value()) + newVal, _, err := types.Uint32.Convert(ctx, gmsLiteral.LiteralValue()) if err != nil { return nil, transform.NewTree, err } @@ -127,7 +128,7 @@ func typeSanitizerLiterals(ctx context.Context, gmsLiteral *expression.Literal) } return pgexprs.NewRawLiteralInt64(int64(newVal.(uint32))), transform.NewTree, nil case query.Type_UINT64, query.Type_SET: - newVal, _, err := types.Uint64.Convert(ctx, gmsLiteral.Value()) + newVal, _, err := types.Uint64.Convert(ctx, gmsLiteral.LiteralValue()) if err != nil { return nil, transform.NewTree, err } @@ -137,7 +138,7 @@ func typeSanitizerLiterals(ctx context.Context, gmsLiteral *expression.Literal) newLiteral, err := pgexprs.NewNumericLiteral(strconv.FormatUint(newVal.(uint64), 10)) return newLiteral, transform.NewTree, err case query.Type_FLOAT32: - newVal, _, err := types.Float32.Convert(ctx, gmsLiteral.Value()) + newVal, _, err := types.Float32.Convert(ctx, gmsLiteral.LiteralValue()) if err != nil { return nil, transform.NewTree, err } @@ -146,7 +147,7 @@ func typeSanitizerLiterals(ctx context.Context, gmsLiteral *expression.Literal) } return pgexprs.NewRawLiteralFloat32(newVal.(float32)), transform.NewTree, nil case query.Type_FLOAT64: - newVal, _, err := types.Float64.Convert(ctx, gmsLiteral.Value()) + newVal, _, err := types.Float64.Convert(ctx, gmsLiteral.LiteralValue()) if err != nil { return nil, transform.NewTree, err } @@ -155,13 +156,13 @@ func typeSanitizerLiterals(ctx context.Context, gmsLiteral *expression.Literal) } return pgexprs.NewRawLiteralFloat64(newVal.(float64)), transform.NewTree, nil case query.Type_DECIMAL: - dec, ok := gmsLiteral.Value().(decimal.Decimal) + dec, ok := gmsLiteral.LiteralValue().(decimal.Decimal) if !ok { - return nil, transform.NewTree, errors.Errorf("SANITIZER: expected decimal type: %T", gmsLiteral.Value()) + return nil, transform.NewTree, errors.Errorf("SANITIZER: expected decimal type: %T", gmsLiteral.LiteralValue()) } return pgexprs.NewRawLiteralNumeric(dec), transform.NewTree, nil case query.Type_DATE, query.Type_DATETIME, query.Type_TIMESTAMP: - newVal, _, err := types.Datetime.Convert(ctx, gmsLiteral.Value()) + newVal, _, err := types.Datetime.Convert(ctx, gmsLiteral.LiteralValue()) if err != nil { return nil, transform.NewTree, err } @@ -170,13 +171,13 @@ func typeSanitizerLiterals(ctx context.Context, gmsLiteral *expression.Literal) } return pgexprs.NewRawLiteralTimestamp(newVal.(time.Time)), transform.NewTree, nil case query.Type_CHAR, query.Type_VARCHAR, query.Type_TEXT: - str, ok := gmsLiteral.Value().(string) + str, ok := gmsLiteral.LiteralValue().(string) if !ok { - return nil, transform.NewTree, errors.Errorf("SANITIZER: expected string type: %T", gmsLiteral.Value()) + return nil, transform.NewTree, errors.Errorf("SANITIZER: expected string type: %T", gmsLiteral.LiteralValue()) } return pgexprs.NewUnknownLiteral(str), transform.NewTree, nil case query.Type_BINARY, query.Type_VARBINARY, query.Type_BLOB: - newVal := gmsLiteral.Value() + newVal := gmsLiteral.LiteralValue() if newVal == nil { return pgexprs.NewNullLiteral(), transform.NewTree, nil } else if str, ok := newVal.(string); ok { @@ -184,15 +185,15 @@ func typeSanitizerLiterals(ctx context.Context, gmsLiteral *expression.Literal) } else if b, ok := newVal.([]byte); ok { return pgexprs.NewUnknownLiteral(string(b)), transform.NewTree, nil } - return nil, transform.NewTree, errors.Errorf("SANITIZER: invalid binary type: %T", gmsLiteral.Value()) + return nil, transform.NewTree, errors.Errorf("SANITIZER: invalid binary type: %T", gmsLiteral.LiteralValue()) case query.Type_JSON: - newVal := gmsLiteral.Value() + newVal := gmsLiteral.LiteralValue() if newVal == nil { return pgexprs.NewNullLiteral(), transform.NewTree, nil } str, ok := newVal.(string) if !ok { - return nil, transform.NewTree, errors.Errorf("SANITIZER: expected string type: %T", gmsLiteral.Value()) + return nil, transform.NewTree, errors.Errorf("SANITIZER: expected string type: %T", gmsLiteral.LiteralValue()) } return pgexprs.NewUnknownLiteral(str), transform.NewTree, nil case query.Type_NULL_TYPE: diff --git a/server/ast/limit.go b/server/ast/limit.go index a105defa73..2dcbbf9e26 100644 --- a/server/ast/limit.go +++ b/server/ast/limit.go @@ -16,11 +16,10 @@ package ast import ( "github.com/cockroachdb/errors" - + pgexprs "github.com/dolthub/doltgresql/server/expression" vitess "github.com/dolthub/vitess/go/vt/sqlparser" "github.com/dolthub/doltgresql/postgres/parser/sem/tree" - pgexprs "github.com/dolthub/doltgresql/server/expression" ) // nodeLimit handles *tree.Limit nodes. @@ -40,11 +39,11 @@ func nodeLimit(ctx *Context, node *tree.Limit) (*vitess.Limit, error) { if err != nil { return nil, err } - // GMS is hardcoded to expect vitess.SQLVal for expressions such as `LIMIT 1 OFFSET 1`. - // We need to remove the hard dependency, but for now, we'll just convert our literals to a vitess.SQLVal. + + // MySQL bound checking happens in the parser, we must do it manually if injectedExpr, ok := count.(vitess.InjectedExpr); ok { if literal, ok := injectedExpr.Expression.(*pgexprs.Literal); ok { - l := literal.Value() + l := literal.LiteralValue() limitValue, err := int64ValueForLimit(l) if err != nil { return nil, err @@ -53,13 +52,11 @@ func nodeLimit(ctx *Context, node *tree.Limit) (*vitess.Limit, error) { if limitValue < 0 { return nil, errors.Errorf("LIMIT must be greater than or equal to 0") } - - count = literal.ToVitessLiteral() } } if injectedExpr, ok := offset.(vitess.InjectedExpr); ok { if literal, ok := injectedExpr.Expression.(*pgexprs.Literal); ok { - o := literal.Value() + o := literal.LiteralValue() offsetVal, err := int64ValueForLimit(o) if err != nil { return nil, err @@ -68,10 +65,9 @@ func nodeLimit(ctx *Context, node *tree.Limit) (*vitess.Limit, error) { if offsetVal < 0 { return nil, errors.Errorf("OFFSET must be greater than or equal to 0") } - - offset = literal.ToVitessLiteral() } } + return &vitess.Limit{ Offset: offset, Rowcount: count, diff --git a/server/connection_data.go b/server/connection_data.go index cf1de67aaa..b19f03f665 100644 --- a/server/connection_data.go +++ b/server/connection_data.go @@ -113,15 +113,32 @@ func extractBindVarTypes(queryPlan sql.Node) ([]uint32, error) { types := make([]uint32, 0) var err error - extractBindVars := func(expr sql.Expression) bool { + var extractBindVars func(n sql.Node, expr sql.Expression) bool + extractBindVars = func(n sql.Node, expr sql.Expression) bool { if err != nil { return false } + switch e := expr.(type) { + // Subquery doesn't walk its Node child via Expressions, so we must walk it separately here + case *plan.Subquery: + transform.InspectExpressionsWithNode(e.Query, extractBindVars) case *expression.BindVar: var typOid uint32 if doltgresType, ok := e.Type().(*pgtypes.DoltgresType); ok { typOid = id.Cache().ToOID(doltgresType.ID.AsId()) + } else if _, ok := e.Type().(sql.DeferredType); ok { + // for a deferred type, we can make a guess to its type based on the containing node + switch n.(type) { + case *plan.Limit: + typOid = uint32(oid.T_int4) + default: + typOid, err = VitessTypeToObjectID(e.Type().Type()) + if err != nil { + err = errors.Errorf("could not determine OID for placeholder %s: %w", e.Name, err) + return false + } + } } else { // TODO: should remove usage non doltgres type typOid, err = VitessTypeToObjectID(e.Type().Type()) @@ -163,7 +180,7 @@ func extractBindVarTypes(queryPlan sql.Node) ([]uint32, error) { return true } - transform.InspectExpressions(inspectNode, extractBindVars) + transform.InspectExpressionsWithNode(inspectNode, extractBindVars) return types, err } diff --git a/server/connection_handler.go b/server/connection_handler.go index 022ed63dfe..6ce633cad2 100644 --- a/server/connection_handler.go +++ b/server/connection_handler.go @@ -510,6 +510,9 @@ func (h *ConnectionHandler) handleParse(message *pgproto3.Parse) error { // NOTE: This is used for Prepared Statement Tests only. bindVarTypes, err = extractBindVarTypes(analyzedPlan) if err != nil { + if printErrorStackTraces { + fmt.Printf("Error extracting bind var types: %+v\n", err) + } return err } } diff --git a/server/expression/literal.go b/server/expression/literal.go index d95e989b1e..9ad7317393 100644 --- a/server/expression/literal.go +++ b/server/expression/literal.go @@ -41,6 +41,7 @@ type Literal struct { var _ vitess.Injectable = (*Literal)(nil) var _ sql.Expression = (*Literal)(nil) var _ framework.LiteralInterface = (*Literal)(nil) +var _ sql.LiteralExpression = (*Literal)(nil) // NewNumericLiteral returns a new *Literal containing a NUMERIC value. func NewNumericLiteral(numericValue string) (*Literal, error) { @@ -330,8 +331,8 @@ func (l *Literal) Type() sql.Type { return l.typ } -// Value returns the literal value. -func (l *Literal) Value() any { +// LiteralValue implements the sql.LiteralExpression interface +func (l *Literal) LiteralValue() interface{} { return l.value } diff --git a/testing/go/enginetest/doltgres_engine_test.go b/testing/go/enginetest/doltgres_engine_test.go index f087628160..90834359a6 100755 --- a/testing/go/enginetest/doltgres_engine_test.go +++ b/testing/go/enginetest/doltgres_engine_test.go @@ -611,6 +611,7 @@ func TestScripts(t *testing.T) { "select * from vt where v = cast('def' as char(6));", // incorrect result "select * from vt where v < cast('def' as char(6));", // incorrect result "select * from vt where v >= cast('def' as char(6));", // incorrect result + "histogram bucket merging error for implementor buckets", // with recursive syntax }) defer h.Close() enginetest.TestScripts(t, h) diff --git a/testing/go/prepared_statement_test.go b/testing/go/prepared_statement_test.go index 6c518769de..18d8510ba4 100755 --- a/testing/go/prepared_statement_test.go +++ b/testing/go/prepared_statement_test.go @@ -31,11 +31,9 @@ func TestPreparedPgCatalog(t *testing.T) { } var preparedStatementTests = []ScriptTest{ - { Name: "Expressions without tables", Assertions: []ScriptTestAssertion{ - { Query: "SELECT CONCAT($1::text, $2::text)", BindVars: []any{"hello", "world"}, @@ -56,13 +54,11 @@ var preparedStatementTests = []ScriptTest{ Name: "Expressions with tables", Assertions: []ScriptTestAssertion{ { - Skip: true, // TODO: expected 0 arguments, got 1 Query: "SELECT EXISTS(SELECT 1 FROM pg_namespace WHERE nspname = $1);", BindVars: []any{"public"}, Expected: []sql.Row{{1}}, }, { - Skip: true, // TODO: could not determine OID for placeholder v1: unsupported type: EXPRESSION Query: "SELECT nspname FROM pg_namespace LIMIT $1;", BindVars: []any{1}, Expected: []sql.Row{{"dolt"}},