diff --git a/enginetest/queries/script_queries.go b/enginetest/queries/script_queries.go index 7a7c65247c..69939bbb4a 100644 --- a/enginetest/queries/script_queries.go +++ b/enginetest/queries/script_queries.go @@ -120,6 +120,74 @@ type ScriptTestAssertion struct { // Unlike other engine tests, ScriptTests must be self-contained. No other tables are created outside the definition of // the tests. var ScriptTests = []ScriptTest{ + { + // Regression test for https://github.com/dolthub/dolt/issues/9641 + Name: "bit union max1err dolt#9641", + SetUpScript: []string{ + "CREATE TABLE report_card (id INT PRIMARY KEY, archived BIT(1))", + "INSERT INTO report_card VALUES (1, 0)", + }, + Assertions: []ScriptTestAssertion{ + { + // max1err + Query: `SELECT archived FROM report_card WHERE id = 1 + UNION ALL + SELECT 48 FROM report_card WHERE id = 1`, + Expected: []sql.Row{{int64(0)}, {int64(48)}}, + }, + }, + }, + { + // Regression test for https://github.com/dolthub/dolt/issues/9641 + Name: "bit union comprehensive regression test dolt#9641", + Dialect: "mysql", + SetUpScript: []string{ + `CREATE TABLE t1 ( + id int PRIMARY KEY, + name varchar(254), + archived bit(1) NOT NULL DEFAULT b'0', + archived_directly bit(1) NOT NULL DEFAULT b'0' + )`, + `CREATE TABLE t2 ( + id int PRIMARY KEY, + name varchar(254), + archived bit(1) NOT NULL DEFAULT b'0' + )`, + "INSERT INTO t1 VALUES (1, 'Card1', b'0', b'0')", + "INSERT INTO t2 VALUES (2, 'Collection1', b'0')", + }, + Assertions: []ScriptTestAssertion{ + { + Query: `SELECT *, + COUNT(*) OVER () AS total_count + FROM ( + SELECT + 5 AS model_ranking, + id, + name, + archived, + archived_directly + FROM t1 + WHERE archived_directly = FALSE + + UNION ALL + + SELECT + 7 AS model_ranking, + id, + name, + archived, + NULL AS archived_directly + FROM t2 + WHERE archived = FALSE AND id <> 1 + ) AS dummy_alias`, + Expected: []sql.Row{ + {int64(5), int64(1), "Card1", uint64(0), int64(0), int64(2)}, + {int64(7), int64(2), "Collection1", uint64(0), nil, int64(2)}, + }, + }, + }, + }, { Name: "outer join finish unmatched right side", SetUpScript: []string{ diff --git a/sql/analyzer/resolve_unions.go b/sql/analyzer/resolve_unions.go index 2134709321..90df9263e5 100644 --- a/sql/analyzer/resolve_unions.go +++ b/sql/analyzer/resolve_unions.go @@ -99,6 +99,10 @@ func finalizeUnions(ctx *sql.Context, a *Analyzer, n sql.Node, scope *plan.Scope if err != nil { return nil, transform.SameTree, err } + + // UNION can return multiple rows even when child queries use LIMIT 1, so disable Max1Row optimization + qFlags.Unset(sql.QFlagMax1Row) + return newN, transform.NewTree, nil }) } diff --git a/sql/expression/convert.go b/sql/expression/convert.go index 6a07498e3d..67ffc8b3c0 100644 --- a/sql/expression/convert.go +++ b/sql/expression/convert.go @@ -123,6 +123,9 @@ func GetConvertToType(l, r sql.Type) string { if types.IsDecimal(l) || types.IsDecimal(r) { return ConvertToDecimal } + if types.IsBit(l) || types.IsBit(r) { + return ConvertToSigned + } if types.IsUnsigned(l) && types.IsUnsigned(r) { return ConvertToUnsigned } diff --git a/sql/plan/set_op.go b/sql/plan/set_op.go index 68520a5fba..c9fb6b11e4 100644 --- a/sql/plan/set_op.go +++ b/sql/plan/set_op.go @@ -19,6 +19,7 @@ import ( "strings" "github.com/dolthub/go-mysql-server/sql" + "github.com/dolthub/go-mysql-server/sql/types" ) const ( @@ -100,6 +101,7 @@ func (s *SetOp) Schema() sql.Schema { for i := range ls { c := *ls[i] if i < len(rs) { + c.Type = types.GeneralizeTypes(ls[i].Type, rs[i].Type) c.Nullable = ls[i].Nullable || rs[i].Nullable } ret[i] = &c diff --git a/sql/types/bit.go b/sql/types/bit.go index 7f9ef77d95..d947c4db50 100644 --- a/sql/types/bit.go +++ b/sql/types/bit.go @@ -199,16 +199,7 @@ func (t BitType_) SQL(ctx *sql.Context, dest []byte, v interface{}) (sqltypes.Va } bitVal := value.(uint64) - var data []byte - for i := uint64(0); i < uint64(t.numOfBits); i += 8 { - data = append(data, byte(bitVal>>i)) - } - for i, j := 0, len(data)-1; i < j; i, j = i+1, j-1 { - data[i], data[j] = data[j], data[i] - } - val := data - - return sqltypes.MakeTrusted(sqltypes.Bit, val), nil + return sqltypes.NewUint64(bitVal), nil } // String implements Type interface. @@ -218,7 +209,8 @@ func (t BitType_) String() string { // Type implements Type interface. func (t BitType_) Type() query.Type { - return sqltypes.Bit + // Use Uint64 for MySQL driver compatibility + return sqltypes.Uint64 } // ValueType implements Type interface. diff --git a/sql/types/conversion.go b/sql/types/conversion.go index b08deb14d3..9cf5961179 100644 --- a/sql/types/conversion.go +++ b/sql/types/conversion.go @@ -693,19 +693,6 @@ func GeneralizeTypes(a, b sql.Type) sql.Type { return LongBlob } - aIsBit := IsBit(a) - bIsBit := IsBit(b) - if aIsBit && bIsBit { - // TODO: match max bits to max of max bits between a and b - return a.Promote() - } - if aIsBit { - a = Int64 - } - if bIsBit { - b = Int64 - } - aIsYear := IsYear(a) bIsYear := IsYear(b) if aIsYear && bIsYear {