Skip to content

dolthub/dolt#9641 - Fix BIT Overflow #3148

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 68 additions & 0 deletions enginetest/queries/script_queries.go
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,74 @@ type ScriptTestAssertion struct {
// Unlike other engine tests, ScriptTests must be self-contained. No other tables are created outside the definition of
// the tests.
var ScriptTests = []ScriptTest{
{
// Regression test for https://github.com/dolthub/dolt/issues/9641
Name: "bit union max1err dolt#9641",
SetUpScript: []string{
"CREATE TABLE report_card (id INT PRIMARY KEY, archived BIT(1))",
"INSERT INTO report_card VALUES (1, 0)",
},
Assertions: []ScriptTestAssertion{
{
// max1err
Query: `SELECT archived FROM report_card WHERE id = 1
UNION ALL
SELECT 48 FROM report_card WHERE id = 1`,
Expected: []sql.Row{{int64(0)}, {int64(48)}},
},
},
},
{
// Regression test for https://github.com/dolthub/dolt/issues/9641
Name: "bit union comprehensive regression test dolt#9641",
Dialect: "mysql",
SetUpScript: []string{
`CREATE TABLE t1 (
id int PRIMARY KEY,
name varchar(254),
archived bit(1) NOT NULL DEFAULT b'0',
archived_directly bit(1) NOT NULL DEFAULT b'0'
)`,
`CREATE TABLE t2 (
id int PRIMARY KEY,
name varchar(254),
archived bit(1) NOT NULL DEFAULT b'0'
)`,
"INSERT INTO t1 VALUES (1, 'Card1', b'0', b'0')",
"INSERT INTO t2 VALUES (2, 'Collection1', b'0')",
},
Assertions: []ScriptTestAssertion{
{
Query: `SELECT *,
COUNT(*) OVER () AS total_count
FROM (
SELECT
5 AS model_ranking,
id,
name,
archived,
archived_directly
FROM t1
WHERE archived_directly = FALSE

UNION ALL

SELECT
7 AS model_ranking,
id,
name,
archived,
NULL AS archived_directly
FROM t2
WHERE archived = FALSE AND id <> 1
) AS dummy_alias`,
Expected: []sql.Row{
{int64(5), int64(1), "Card1", int64(0), int64(0), int64(2)},
{int64(7), int64(2), "Collection1", int64(0), nil, int64(2)},
},
},
},
},
{
Name: "outer join finish unmatched right side",
SetUpScript: []string{
Expand Down
4 changes: 4 additions & 0 deletions sql/analyzer/resolve_unions.go
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,10 @@ func finalizeUnions(ctx *sql.Context, a *Analyzer, n sql.Node, scope *plan.Scope
if err != nil {
return nil, transform.SameTree, err
}

// UNION can return multiple rows even when child queries use LIMIT 1, so disable Max1Row optimization
qFlags.Unset(sql.QFlagMax1Row)

return newN, transform.NewTree, nil
})
}
3 changes: 3 additions & 0 deletions sql/expression/convert.go
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,9 @@ func GetConvertToType(l, r sql.Type) string {
if types.IsDecimal(l) || types.IsDecimal(r) {
return ConvertToDecimal
}
if types.IsBit(l) || types.IsBit(r) {
return ConvertToSigned
}
if types.IsUnsigned(l) && types.IsUnsigned(r) {
return ConvertToUnsigned
}
Expand Down
22 changes: 19 additions & 3 deletions sql/iters/rel_iters.go
Original file line number Diff line number Diff line change
Expand Up @@ -601,8 +601,9 @@ func (di *distinctIter) Dispose() {
}

type UnionIter struct {
Cur sql.RowIter
NextIter func(ctx *sql.Context) (sql.RowIter, error)
Cur sql.RowIter
NextIter func(ctx *sql.Context) (sql.RowIter, error)
ResultSchema sql.Schema
}

func (ui *UnionIter) Next(ctx *sql.Context) (sql.Row, error) {
Expand All @@ -620,8 +621,23 @@ func (ui *UnionIter) Next(ctx *sql.Context) (sql.Row, error) {
if err != nil {
return nil, err
}
return ui.Cur.Next(ctx)
res, err = ui.Cur.Next(ctx)
if err != nil {
return nil, err
}
}

// Convert BIT values to Int64 when schema generalization changed types for server engine compatibility
if ui.ResultSchema != nil {
for i, val := range res {
if i < len(ui.ResultSchema) && val != nil {
if converted, _, err := ui.ResultSchema[i].Type.Convert(ctx, val); err == nil {
res[i] = converted
}
}
}
}

return res, err
}

Expand Down
2 changes: 2 additions & 0 deletions sql/plan/set_op.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import (
"strings"

"github.com/dolthub/go-mysql-server/sql"
"github.com/dolthub/go-mysql-server/sql/types"
)

const (
Expand Down Expand Up @@ -100,6 +101,7 @@ func (s *SetOp) Schema() sql.Schema {
for i := range ls {
c := *ls[i]
if i < len(rs) {
c.Type = types.GeneralizeTypes(ls[i].Type, rs[i].Type)
c.Nullable = ls[i].Nullable || rs[i].Nullable
}
ret[i] = &c
Expand Down
1 change: 1 addition & 0 deletions sql/rowexec/rel.go
Original file line number Diff line number Diff line change
Expand Up @@ -813,6 +813,7 @@ func (b *BaseBuilder) buildSetOp(ctx *sql.Context, s *plan.SetOp, row sql.Row) (
NextIter: func(ctx *sql.Context) (sql.RowIter, error) {
return b.buildNodeExec(ctx, s.Right(), row)
},
ResultSchema: s.Schema(),
}
case plan.IntersectType:
var iter2 sql.RowIter
Expand Down
5 changes: 5 additions & 0 deletions sql/types/bit.go
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,11 @@ func (t BitType_) SQL(ctx *sql.Context, dest []byte, v interface{}) (sqltypes.Va
if v == nil {
return sqltypes.NULL, nil
}

// Delegate int64 values to Int64.SQL to prevent server engine []uint8 serialization errors
if int64Val, ok := v.(int64); ok {
return Int64.SQL(ctx, dest, int64Val)
}
value, _, err := t.Convert(ctx, v)
if err != nil {
return sqltypes.Value{}, err
Expand Down
18 changes: 5 additions & 13 deletions sql/types/conversion.go
Original file line number Diff line number Diff line change
Expand Up @@ -637,6 +637,11 @@ func generalizeNumberTypes(a, b sql.Type) sql.Type {
// TODO: Create and handle "Illegal mix of collations" error
// TODO: Handle extended types, like DoltgresType
func GeneralizeTypes(a, b sql.Type) sql.Type {
// BIT types must convert to Int64 in UNION to avoid server engine serialization errors with []uint8
if (IsBit(a) || IsBit(b)) && (IsBit(a) && IsBit(b) || IsNullType(a) || IsNullType(b)) {
return Int64
}

if reflect.DeepEqual(a, b) {
return a
}
Expand Down Expand Up @@ -693,19 +698,6 @@ func GeneralizeTypes(a, b sql.Type) sql.Type {
return LongBlob
}

aIsBit := IsBit(a)
bIsBit := IsBit(b)
if aIsBit && bIsBit {
// TODO: match max bits to max of max bits between a and b
return a.Promote()
}
if aIsBit {
a = Int64
}
if bIsBit {
b = Int64
}

aIsYear := IsYear(a)
bIsYear := IsYear(b)
if aIsYear && bIsYear {
Expand Down
Loading