Skip to content

dolthub/dolt#9641 - Fix BIT Overflow #3148

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 13 commits into from
Aug 18, 2025
68 changes: 68 additions & 0 deletions enginetest/queries/script_queries.go
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,74 @@ type ScriptTestAssertion struct {
// Unlike other engine tests, ScriptTests must be self-contained. No other tables are created outside the definition of
// the tests.
var ScriptTests = []ScriptTest{
{
// Regression test for https://github.com/dolthub/dolt/issues/9641
Name: "bit union max1err dolt#9641",
SetUpScript: []string{
"CREATE TABLE report_card (id INT PRIMARY KEY, archived BIT(1))",
"INSERT INTO report_card VALUES (1, 0)",
},
Assertions: []ScriptTestAssertion{
{
// max1err
Query: `SELECT archived FROM report_card WHERE id = 1
UNION ALL
SELECT 48 FROM report_card WHERE id = 1`,
Expected: []sql.Row{{int64(0)}, {int64(48)}},
},
},
},
{
// Regression test for https://github.com/dolthub/dolt/issues/9641
Name: "bit union comprehensive regression test dolt#9641",
Dialect: "mysql",
SetUpScript: []string{
`CREATE TABLE t1 (
id int PRIMARY KEY,
name varchar(254),
archived bit(1) NOT NULL DEFAULT b'0',
archived_directly bit(1) NOT NULL DEFAULT b'0'
)`,
`CREATE TABLE t2 (
id int PRIMARY KEY,
name varchar(254),
archived bit(1) NOT NULL DEFAULT b'0'
)`,
"INSERT INTO t1 VALUES (1, 'Card1', b'0', b'0')",
"INSERT INTO t2 VALUES (2, 'Collection1', b'0')",
},
Assertions: []ScriptTestAssertion{
{
Query: `SELECT *,
COUNT(*) OVER () AS total_count
FROM (
SELECT
5 AS model_ranking,
id,
name,
archived,
archived_directly
FROM t1
WHERE archived_directly = FALSE

UNION ALL

SELECT
7 AS model_ranking,
id,
name,
archived,
NULL AS archived_directly
FROM t2
WHERE archived = FALSE AND id <> 1
) AS dummy_alias`,
Expected: []sql.Row{
{int64(5), int64(1), "Card1", uint64(0), int64(0), int64(2)},
{int64(7), int64(2), "Collection1", uint64(0), nil, int64(2)},
},
},
},
},
{
Name: "outer join finish unmatched right side",
SetUpScript: []string{
Expand Down
19 changes: 17 additions & 2 deletions enginetest/server_engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -378,6 +378,16 @@ func rowIterForGoSqlRows(ctx *sql.Context, sch sql.Schema, rows *gosql.Rows) (sq
func convertValue(ctx *sql.Context, sch sql.Schema, row sql.Row) sql.Row {
for i, col := range sch {
switch col.Type.Type() {
case query.Type_BIT:
if row[i] != nil {
if bytes, ok := row[i].([]byte); ok {
if bt, ok := col.Type.(types.BitType); ok {
if v, _, err := bt.Convert(ctx, bytes); err == nil {
row[i] = v
}
}
}
}
case query.Type_GEOMETRY:
if row[i] != nil {
r, _, err := types.GeometryType{}.Convert(ctx, row[i].([]byte))
Expand Down Expand Up @@ -562,9 +572,12 @@ func emptyRowForSchema(sch sql.Schema) ([]any, error) {
func emptyValuePointerForType(t sql.Type) (any, error) {
switch t.Type() {
case query.Type_INT8, query.Type_INT16, query.Type_INT24, query.Type_INT64,
query.Type_BIT, query.Type_YEAR:
query.Type_YEAR:
var i gosql.NullInt64
return &i, nil
case query.Type_BIT:
var b []byte
return &b, nil
case query.Type_INT32:
var i gosql.NullInt32
return &i, nil
Expand Down Expand Up @@ -624,8 +637,10 @@ func schemaForRows(rows *gosql.Rows) (sql.Schema, error) {

func convertGoSqlType(columnType *gosql.ColumnType) (sql.Type, error) {
switch strings.ToLower(columnType.DatabaseTypeName()) {
case "tinyint", "smallint", "mediumint", "int", "bigint", "bit":
case "tinyint", "smallint", "mediumint", "int", "bigint":
return types.Int64, nil
case "bit":
return types.MustCreateBitType(1), nil
case "unsigned tinyint", "unsigned smallint", "unsigned mediumint", "unsigned int", "unsigned bigint":
return types.Uint64, nil
case "float", "double":
Expand Down
4 changes: 4 additions & 0 deletions sql/analyzer/resolve_unions.go
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,10 @@ func finalizeUnions(ctx *sql.Context, a *Analyzer, n sql.Node, scope *plan.Scope
if err != nil {
return nil, transform.SameTree, err
}

// UNION can return multiple rows even when child queries use LIMIT 1, so disable Max1Row optimization
qFlags.Unset(sql.QFlagMax1Row)

return newN, transform.NewTree, nil
})
}
3 changes: 3 additions & 0 deletions sql/expression/convert.go
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,9 @@ func GetConvertToType(l, r sql.Type) string {
if types.IsDecimal(l) || types.IsDecimal(r) {
return ConvertToDecimal
}
if types.IsBit(l) || types.IsBit(r) {
return ConvertToSigned
}
if types.IsUnsigned(l) && types.IsUnsigned(r) {
return ConvertToUnsigned
}
Expand Down
2 changes: 2 additions & 0 deletions sql/plan/set_op.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import (
"strings"

"github.com/dolthub/go-mysql-server/sql"
"github.com/dolthub/go-mysql-server/sql/types"
)

const (
Expand Down Expand Up @@ -100,6 +101,7 @@ func (s *SetOp) Schema() sql.Schema {
for i := range ls {
c := *ls[i]
if i < len(rs) {
c.Type = types.GeneralizeTypes(ls[i].Type, rs[i].Type)
c.Nullable = ls[i].Nullable || rs[i].Nullable
}
ret[i] = &c
Expand Down
2 changes: 0 additions & 2 deletions sql/types/conversion.go
Original file line number Diff line number Diff line change
Expand Up @@ -692,7 +692,6 @@ func GeneralizeTypes(a, b sql.Type) sql.Type {
// TODO: match blob length to max of the blob lengths
return LongBlob
}

aIsBit := IsBit(a)
bIsBit := IsBit(b)
if aIsBit && bIsBit {
Expand All @@ -705,7 +704,6 @@ func GeneralizeTypes(a, b sql.Type) sql.Type {
if bIsBit {
b = Int64
}

aIsYear := IsYear(a)
bIsYear := IsYear(b)
if aIsYear && bIsYear {
Expand Down
Loading