diff --git a/go.mod b/go.mod index 88040b7ebb..e7e6666a78 100644 --- a/go.mod +++ b/go.mod @@ -6,11 +6,11 @@ require ( github.com/PuerkitoBio/goquery v1.8.1 github.com/cockroachdb/apd/v2 v2.0.3-0.20200518165714-d020e156310a github.com/cockroachdb/errors v1.7.5 - github.com/dolthub/dolt/go v0.40.5-0.20250710172911-4fc9871c00d8 + github.com/dolthub/dolt/go v0.40.5-0.20250711003334-ecce50ff234b github.com/dolthub/dolt/go/gen/proto/dolt/services/eventsapi v0.0.0-20241119094239-f4e529af734d github.com/dolthub/flatbuffers/v23 v23.3.3-dh.2 github.com/dolthub/go-icu-regex v0.0.0-20250327004329-6799764f2dad - github.com/dolthub/go-mysql-server v0.20.1-0.20250707155350-de1be10bcc53 + github.com/dolthub/go-mysql-server v0.20.1-0.20250711173737-d3ba56fd599d github.com/dolthub/sqllogictest/go v0.0.0-20240618184124-ca47f9354216 github.com/dolthub/vitess v0.0.0-20250611225316-90a5898bfe26 github.com/fatih/color v1.13.0 diff --git a/go.sum b/go.sum index fd67d3e3a8..a50cd244e6 100644 --- a/go.sum +++ b/go.sum @@ -258,6 +258,8 @@ github.com/dolthub/aws-sdk-go-ini-parser v0.0.0-20250305001723-2821c37f6c12 h1:I github.com/dolthub/aws-sdk-go-ini-parser v0.0.0-20250305001723-2821c37f6c12/go.mod h1:rN7X8BHwkjPcfMQQ2QTAq/xM3leUSGLfb+1Js7Y6TVo= github.com/dolthub/dolt/go v0.40.5-0.20250710172911-4fc9871c00d8 h1:He85ATMhdHOSD4swSS33aebpRBeRKJBrP8QA8Bw0H1s= github.com/dolthub/dolt/go v0.40.5-0.20250710172911-4fc9871c00d8/go.mod h1:8QI1n1vOwOh4ks23dzyE2KBCHXIpHvJCLdsc3cKcjxQ= +github.com/dolthub/dolt/go v0.40.5-0.20250711003334-ecce50ff234b h1:BwMMFillst0qP5BU1d1ZTC0dcyVnlyQn/M+y9wTE+Xg= +github.com/dolthub/dolt/go v0.40.5-0.20250711003334-ecce50ff234b/go.mod h1:XZslA/JSf53Xo1Gx/SwacF34m18/ZDL6D8tEt2nvTas= github.com/dolthub/dolt/go/gen/proto/dolt/services/eventsapi v0.0.0-20241119094239-f4e529af734d h1:gO9+wrmNHXukPNCO1tpfCcXIdMlW/qppbUStfLvqz/U= github.com/dolthub/dolt/go/gen/proto/dolt/services/eventsapi v0.0.0-20241119094239-f4e529af734d/go.mod h1:L5RDYZbC9BBWmoU2+TjTekeqqhFXX5EqH9ln00O0stY= github.com/dolthub/flatbuffers/v23 v23.3.3-dh.2 h1:u3PMzfF8RkKd3lB9pZ2bfn0qEG+1Gms9599cr0REMww= @@ -268,6 +270,8 @@ github.com/dolthub/go-icu-regex v0.0.0-20250327004329-6799764f2dad h1:66ZPawHszN github.com/dolthub/go-icu-regex v0.0.0-20250327004329-6799764f2dad/go.mod h1:ylU4XjUpsMcvl/BKeRRMXSH7e7WBrPXdSLvnRJYrxEA= github.com/dolthub/go-mysql-server v0.20.1-0.20250707155350-de1be10bcc53 h1:VtqXQv3zqO4GNKYLUad1ldIxTMJsIQefEwRmxrwK4Zo= github.com/dolthub/go-mysql-server v0.20.1-0.20250707155350-de1be10bcc53/go.mod h1:zuYoQ3keJHAvWUWMLzbP9anvR32b3sy1Fm8wB8ukNxQ= +github.com/dolthub/go-mysql-server v0.20.1-0.20250711173737-d3ba56fd599d h1:RMhv7QKbf/1FP6fP5VTvXW9rpyc9hjqjSQlgG56Al5E= +github.com/dolthub/go-mysql-server v0.20.1-0.20250711173737-d3ba56fd599d/go.mod h1:zuYoQ3keJHAvWUWMLzbP9anvR32b3sy1Fm8wB8ukNxQ= github.com/dolthub/gozstd v0.0.0-20240423170813-23a2903bca63 h1:OAsXLAPL4du6tfbBgK0xXHZkOlos63RdKYS3Sgw/dfI= github.com/dolthub/gozstd v0.0.0-20240423170813-23a2903bca63/go.mod h1:lV7lUeuDhH5thVGDCKXbatwKy2KW80L4rMT46n+Y2/Q= github.com/dolthub/ishell v0.0.0-20240701202509-2b217167d718 h1:lT7hE5k+0nkBdj/1UOSFwjWpNxf+LCApbRHgnCA17XE= diff --git a/server/analyzer/init.go b/server/analyzer/init.go index 58c0dfb7b8..75e0d40f79 100644 --- a/server/analyzer/init.go +++ b/server/analyzer/init.go @@ -16,8 +16,11 @@ package analyzer import ( "github.com/dolthub/go-mysql-server/sql/analyzer" + "github.com/dolthub/go-mysql-server/sql/expression" "github.com/dolthub/go-mysql-server/sql/plan" "github.com/dolthub/go-mysql-server/sql/planbuilder" + + pgexpression "github.com/dolthub/doltgresql/server/expression" ) // IDs are basically arbitrary, we just need to ensure that they do not conflict with existing IDs @@ -104,6 +107,8 @@ func initEngine() { plan.ValidateForeignKeyDefinition = validateForeignKeyDefinition planbuilder.IsAggregateFunc = IsAggregateFunc + + expression.DefaultExpressionFactory = pgexpression.PostgresExpressionFactory{} } // IsAggregateFunc checks if the given function name is an aggregate function. This is the entire set supported by diff --git a/server/expression/expr_factory.go b/server/expression/expr_factory.go new file mode 100644 index 0000000000..f9433fab4d --- /dev/null +++ b/server/expression/expr_factory.go @@ -0,0 +1,36 @@ +// Copyright 2025 Dolthub, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package expression + +import ( + "github.com/dolthub/go-mysql-server/sql" + "github.com/dolthub/go-mysql-server/sql/expression" +) + +// PostgresExpressionFactory implements the expression.ExpressionFactory interface and +// allows callers to produce expressions that have custom behavior for Postgres. +type PostgresExpressionFactory struct{} + +var _ expression.ExpressionFactory = (*PostgresExpressionFactory)(nil) + +// NewIsNull implements the expression.ExpressionFactory interface. +func (m PostgresExpressionFactory) NewIsNull(e sql.Expression) sql.Expression { + return NewIsNull(e) +} + +// NewIsNotNull implements the expression.ExpressionFactory interface. +func (m PostgresExpressionFactory) NewIsNotNull(e sql.Expression) sql.Expression { + return NewIsNotNull(e) +} diff --git a/server/expression/is_not_null.go b/server/expression/is_not_null.go new file mode 100644 index 0000000000..d6fc847ac7 --- /dev/null +++ b/server/expression/is_not_null.go @@ -0,0 +1,95 @@ +// Copyright 2025 Dolthub, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package expression + +import ( + "github.com/dolthub/go-mysql-server/sql" + "github.com/dolthub/go-mysql-server/sql/expression" + + pgtypes "github.com/dolthub/doltgresql/server/types" +) + +// IsNotNull is an implementation of sql.Expression for the IS NOT NULL operator +// and includes Postgres-specific logic for handling records and composites. +type IsNotNull struct { + expression.UnaryExpression +} + +var _ sql.Expression = (*IsNotNull)(nil) +var _ sql.CollationCoercible = (*IsNotNull)(nil) +var _ sql.IsNotNullExpression = (*IsNotNull)(nil) + +// NewIsNotNull creates a new IsNotNull expression. +func NewIsNotNull(child sql.Expression) *IsNotNull { + return &IsNotNull{expression.UnaryExpression{Child: child}} +} + +// IsNotNullExpression implements the sql.IsNotNullExpression interface. This function exists primarily +// to ensure the IsNotNullExpression interface has a unique signature. +func (e *IsNotNull) IsNotNullExpression() bool { + return true +} + +// Type implements the Expression interface. +func (e *IsNotNull) Type() sql.Type { + return pgtypes.Bool +} + +// CollationCoercibility implements the interface sql.CollationCoercible. +func (*IsNotNull) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) { + return sql.Collation_binary, 5 +} + +// IsNullable implements the Expression interface. +func (e *IsNotNull) IsNullable() bool { + return false +} + +// Eval implements the Expression interface. +func (e *IsNotNull) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { + v, err := e.Child.Eval(ctx, row) + if err != nil { + return nil, err + } + + // Slices of typed values (e.g. Record and Composite types in Postgres) evaluate + // true for IS NOT NULL only if ALL of their entries are not NULL. + if tupleValue, ok := v.([]pgtypes.RecordValue); ok { + for _, typedValue := range tupleValue { + if typedValue.Value == nil { + return false, nil + } + } + return true, nil + } + + return v != nil, nil +} + +func (e *IsNotNull) String() string { + return e.Child.String() + " IS NOT NULL" +} + +func (e *IsNotNull) DebugString() string { + return sql.DebugString(e.Child) + " IS NOT NULL" +} + +// WithChildren implements the Expression interface. +func (e *IsNotNull) WithChildren(children ...sql.Expression) (sql.Expression, error) { + if len(children) != 1 { + return nil, sql.ErrInvalidChildrenNumber.New(e, len(children), 1) + } + return NewIsNotNull(children[0]), nil +} diff --git a/server/expression/is_null.go b/server/expression/is_null.go new file mode 100644 index 0000000000..8a40d6d409 --- /dev/null +++ b/server/expression/is_null.go @@ -0,0 +1,95 @@ +// Copyright 2025 Dolthub, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package expression + +import ( + "github.com/dolthub/go-mysql-server/sql" + "github.com/dolthub/go-mysql-server/sql/expression" + + pgtypes "github.com/dolthub/doltgresql/server/types" +) + +// IsNull is an implementation of sql.Expression for the IS NULL operator and +// includes Postgres-specific logic for handling records and composites. +type IsNull struct { + expression.UnaryExpression +} + +var _ sql.Expression = (*IsNull)(nil) +var _ sql.CollationCoercible = (*IsNull)(nil) +var _ sql.IsNullExpression = (*IsNull)(nil) + +// NewIsNull creates a new IsNull expression. +func NewIsNull(child sql.Expression) *IsNull { + return &IsNull{expression.UnaryExpression{Child: child}} +} + +// IsNullExpression implements the sql.IsNullExpression interface. This function exists primarily +// to ensure the IsNullExpression interface has a unique signature. +func (e *IsNull) IsNullExpression() bool { + return true +} + +// Type implements the Expression interface. +func (e *IsNull) Type() sql.Type { + return pgtypes.Bool +} + +// CollationCoercibility implements the interface sql.CollationCoercible. +func (*IsNull) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) { + return sql.Collation_binary, 5 +} + +// IsNullable implements the Expression interface. +func (e *IsNull) IsNullable() bool { + return false +} + +// Eval implements the Expression interface. +func (e *IsNull) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { + v, err := e.Child.Eval(ctx, row) + if err != nil { + return nil, err + } + + // Slices of typed values (e.g. Record and Composite types in Postgres) evaluate + // to NULL if all of their entries are NULL. + if tupleValue, ok := v.([]pgtypes.RecordValue); ok { + for _, typedValue := range tupleValue { + if typedValue.Value != nil { + return false, nil + } + } + return true, nil + } + + return v == nil, nil +} + +func (e *IsNull) String() string { + return e.Child.String() + " IS NULL" +} + +func (e *IsNull) DebugString() string { + return sql.DebugString(e.Child) + " IS NULL" +} + +// WithChildren implements the Expression interface. +func (e *IsNull) WithChildren(children ...sql.Expression) (sql.Expression, error) { + if len(children) != 1 { + return nil, sql.ErrInvalidChildrenNumber.New(e, len(children), 1) + } + return NewIsNull(children[0]), nil +} diff --git a/server/functions/record.go b/server/functions/record.go index 8a1ebc4389..377116f2e1 100644 --- a/server/functions/record.go +++ b/server/functions/record.go @@ -53,7 +53,7 @@ var record_out = framework.Function1{ Callable: func(ctx *sql.Context, t [2]*pgtypes.DoltgresType, val any) (any, error) { values, ok := val.([]pgtypes.RecordValue) if !ok { - return nil, fmt.Errorf("expected []any, but got %T", val) + return nil, fmt.Errorf("expected []RecordValue, but got %T", val) } return pgtypes.RecordToString(ctx, values) }, @@ -79,7 +79,7 @@ var record_send = framework.Function1{ Callable: func(ctx *sql.Context, t [2]*pgtypes.DoltgresType, val any) (any, error) { values, ok := val.([]pgtypes.RecordValue) if !ok { - return nil, fmt.Errorf("expected []any, but got %T", val) + return nil, fmt.Errorf("expected []RecordValue, but got %T", val) } output, err := pgtypes.RecordToString(ctx, values) if err != nil { diff --git a/server/types/record.go b/server/types/record.go index f555498e66..f5fbdd01a7 100644 --- a/server/types/record.go +++ b/server/types/record.go @@ -15,6 +15,8 @@ package types import ( + "github.com/dolthub/go-mysql-server/sql" + "github.com/dolthub/doltgresql/core/id" ) @@ -56,9 +58,9 @@ var Record = &DoltgresType{ CompareFunc: toFuncID("-"), } -// RecordValue holds the value of a single field in a record, including type information for the -// field value. +// RecordValue represents a single value in a record, along with its +// associated type. type RecordValue struct { Value any - Type *DoltgresType + Type sql.Type } diff --git a/server/types/utils.go b/server/types/utils.go index e041b5e5d3..38cc9910bc 100644 --- a/server/types/utils.go +++ b/server/types/utils.go @@ -173,11 +173,16 @@ func RecordToString(ctx *sql.Context, fields []RecordValue) (any, error) { continue } - str, err := value.Type.IoOutput(ctx, value.Value) + doltgresType, ok := value.Type.(*DoltgresType) + if !ok { + return nil, fmt.Errorf(`expected *DoltgresType but found: %T`, value.Type) + } + + str, err := doltgresType.IoOutput(ctx, value.Value) if err != nil { return "", err } - if value.Type.ID == Bool.ID { + if doltgresType.ID == Bool.ID { str = string(str[0]) } diff --git a/testing/go/record_test.go b/testing/go/record_test.go index e689ee7855..a93c4b7543 100644 --- a/testing/go/record_test.go +++ b/testing/go/record_test.go @@ -257,20 +257,14 @@ func TestRecords(t *testing.T) { Name: "ROW() casting and type inference", Assertions: []ScriptTestAssertion{ { - // TODO: ERROR: unknown type with oid: 2249 - Skip: true, Query: "SELECT ROW(1, 'a')::record;", Expected: []sql.Row{{"(1,a)"}}, }, { - // TODO: This does not return an error yet - Skip: true, Query: "SELECT ROW(1, 2) = ROW(1, 'two');", ExpectedErr: "invalid input syntax", }, { - // TODO: interface conversion panic - Skip: true, Query: "SELECT ROW(1, 2) = ROW(1, '2');", Expected: []sql.Row{{"t"}}, }, @@ -312,37 +306,41 @@ func TestRecords(t *testing.T) { ExpectedErr: "unequal number of entries", }, { - // TODO: expression.IsNull in GMS is used in this evaluation, but returns - // false for this case, because the record evaluates to []any{nil} - // instead of just nil. - Skip: true, + Query: "SELECT NULL::record IS NULL", + Expected: []sql.Row{{"t"}}, + }, + { Query: "SELECT ROW(NULL) IS NULL", Expected: []sql.Row{{"t"}}, }, { - // TODO: expression.IsNull in GMS is used in this evaluation, but returns - // false for this case, because the record evaluates to []any{nil} - // instead of just nil. - Skip: true, Query: "SELECT ROW(NULL, NULL, NULL) IS NULL;", Expected: []sql.Row{{"t"}}, }, { Query: "SELECT ROW(NULL, 42, NULL) IS NULL;", - Expected: []sql.Row{{0}}, + Expected: []sql.Row{{"f"}}, }, { Query: "SELECT ROW(42) IS NULL", - Expected: []sql.Row{{0}}, + Expected: []sql.Row{{"f"}}, }, { - // TODO: expression.IsNull in GMS is used in this evaluation (wrapped with - // an expression.Not), but returns true for this case, because the record - // evaluates to []any{nil} instead of just nil. - Skip: true, Query: "SELECT ROW(NULL) IS NOT NULL;", Expected: []sql.Row{{"f"}}, }, + { + Query: "SELECT ROW(NULL, NULL) IS NOT NULL;", + Expected: []sql.Row{{"f"}}, + }, + { + Query: "SELECT ROW(NULL, 1) IS NOT NULL;", + Expected: []sql.Row{{"f"}}, + }, + { + Query: "SELECT ROW(1, 1) IS NOT NULL;", + Expected: []sql.Row{{"t"}}, + }, { Query: "SELECT ROW(42) IS NOT NULL;", Expected: []sql.Row{{"t"}},