Skip to content

Commit 30de574

Browse files
authored
Merge pull request #3148 from dolthub/elianddb/9641-union-bit-field-overflow
dolthub/dolt#9641 - Fix BIT Overflow
2 parents 4678e45 + 1713c21 commit 30de574

File tree

6 files changed

+108
-4
lines changed

6 files changed

+108
-4
lines changed

enginetest/queries/script_queries.go

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,88 @@ type ScriptTestAssertion struct {
120120
// Unlike other engine tests, ScriptTests must be self-contained. No other tables are created outside the definition of
121121
// the tests.
122122
var ScriptTests = []ScriptTest{
123+
{
124+
// Regression test for https://github.com/dolthub/dolt/issues/9641
125+
Name: "bit union max1err dolt#9641",
126+
SetUpScript: []string{
127+
"CREATE TABLE report_card (id INT PRIMARY KEY, archived BIT(1))",
128+
"INSERT INTO report_card VALUES (1, 0)",
129+
},
130+
Assertions: []ScriptTestAssertion{
131+
{
132+
// max1err
133+
Query: `SELECT archived FROM report_card WHERE id = 1
134+
UNION ALL
135+
SELECT 48 FROM report_card WHERE id = 1`,
136+
Expected: []sql.Row{{int64(0)}, {int64(48)}},
137+
},
138+
},
139+
},
140+
{
141+
// Regression test for https://github.com/dolthub/dolt/issues/9641
142+
Name: "bit union comprehensive regression test dolt#9641",
143+
Dialect: "mysql",
144+
SetUpScript: []string{
145+
`CREATE TABLE t1 (
146+
id int PRIMARY KEY,
147+
name varchar(254),
148+
archived bit(1) NOT NULL DEFAULT b'0',
149+
archived_directly bit(1) NOT NULL DEFAULT b'0'
150+
)`,
151+
`CREATE TABLE t2 (
152+
id int PRIMARY KEY,
153+
name varchar(254),
154+
archived bit(1) NOT NULL DEFAULT b'0'
155+
)`,
156+
"INSERT INTO t1 VALUES (1, 'Card1', b'0', b'0')",
157+
"INSERT INTO t2 VALUES (2, 'Collection1', b'0')",
158+
},
159+
Assertions: []ScriptTestAssertion{
160+
{
161+
Query: `SELECT *,
162+
COUNT(*) OVER () AS total_count
163+
FROM (
164+
SELECT
165+
5 AS model_ranking,
166+
id,
167+
name,
168+
archived,
169+
archived_directly
170+
FROM t1
171+
WHERE archived_directly = FALSE
172+
173+
UNION ALL
174+
175+
SELECT
176+
7 AS model_ranking,
177+
id,
178+
name,
179+
archived,
180+
NULL AS archived_directly
181+
FROM t2
182+
WHERE archived = FALSE AND id <> 1
183+
) AS dummy_alias`,
184+
Expected: []sql.Row{
185+
{int64(5), int64(1), "Card1", uint64(0), int64(0), int64(2)},
186+
{int64(7), int64(2), "Collection1", uint64(0), nil, int64(2)},
187+
},
188+
},
189+
},
190+
},
191+
{
192+
Name: "bits don't work on server",
193+
Dialect: "mysql",
194+
SetUpScript: []string{
195+
"create table t (b bit(1));",
196+
"insert into t values (1)",
197+
},
198+
Assertions: []ScriptTestAssertion{
199+
{
200+
Query: "select * from t;",
201+
Expected: []sql.Row{{uint64(1)}},
202+
},
203+
},
204+
},
123205
{
124206
Name: "outer join finish unmatched right side",
125207
SetUpScript: []string{

enginetest/server_engine.go

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -377,6 +377,16 @@ func rowIterForGoSqlRows(ctx *sql.Context, sch sql.Schema, rows *gosql.Rows) (sq
377377
func convertValue(ctx *sql.Context, sch sql.Schema, row sql.Row) sql.Row {
378378
for i, col := range sch {
379379
switch col.Type.Type() {
380+
case query.Type_BIT:
381+
if row[i] != nil {
382+
if bytes, ok := row[i].([]byte); ok {
383+
if bt, ok := col.Type.(types.BitType); ok {
384+
if v, _, err := bt.Convert(ctx, bytes); err == nil {
385+
row[i] = v
386+
}
387+
}
388+
}
389+
}
380390
case query.Type_GEOMETRY:
381391
if row[i] != nil {
382392
r, _, err := types.GeometryType{}.Convert(ctx, row[i].([]byte))
@@ -561,9 +571,12 @@ func emptyRowForSchema(sch sql.Schema) ([]any, error) {
561571
func emptyValuePointerForType(t sql.Type) (any, error) {
562572
switch t.Type() {
563573
case query.Type_INT8, query.Type_INT16, query.Type_INT24, query.Type_INT64,
564-
query.Type_BIT, query.Type_YEAR:
574+
query.Type_YEAR:
565575
var i gosql.NullInt64
566576
return &i, nil
577+
case query.Type_BIT:
578+
var b []byte
579+
return &b, nil
567580
case query.Type_INT32:
568581
var i gosql.NullInt32
569582
return &i, nil
@@ -623,8 +636,10 @@ func schemaForRows(rows *gosql.Rows) (sql.Schema, error) {
623636

624637
func convertGoSqlType(columnType *gosql.ColumnType) (sql.Type, error) {
625638
switch strings.ToLower(columnType.DatabaseTypeName()) {
626-
case "tinyint", "smallint", "mediumint", "int", "bigint", "bit":
639+
case "tinyint", "smallint", "mediumint", "int", "bigint":
627640
return types.Int64, nil
641+
case "bit":
642+
return types.MustCreateBitType(1), nil
628643
case "unsigned tinyint", "unsigned smallint", "unsigned mediumint", "unsigned int", "unsigned bigint":
629644
return types.Uint64, nil
630645
case "float", "double":

sql/analyzer/resolve_unions.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,10 @@ func finalizeUnions(ctx *sql.Context, a *Analyzer, n sql.Node, scope *plan.Scope
9999
if err != nil {
100100
return nil, transform.SameTree, err
101101
}
102+
103+
// UNION can return multiple rows even when child queries use LIMIT 1, so disable Max1Row optimization
104+
qFlags.Unset(sql.QFlagMax1Row)
105+
102106
return newN, transform.NewTree, nil
103107
})
104108
}

sql/expression/convert.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,9 @@ func GetConvertToType(l, r sql.Type) string {
123123
if types.IsDecimal(l) || types.IsDecimal(r) {
124124
return ConvertToDecimal
125125
}
126+
if types.IsBit(l) || types.IsBit(r) {
127+
return ConvertToSigned
128+
}
126129
if types.IsUnsigned(l) && types.IsUnsigned(r) {
127130
return ConvertToUnsigned
128131
}

sql/plan/set_op.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ import (
1919
"strings"
2020

2121
"github.com/dolthub/go-mysql-server/sql"
22+
"github.com/dolthub/go-mysql-server/sql/types"
2223
)
2324

2425
const (
@@ -100,6 +101,7 @@ func (s *SetOp) Schema() sql.Schema {
100101
for i := range ls {
101102
c := *ls[i]
102103
if i < len(rs) {
104+
c.Type = types.GeneralizeTypes(ls[i].Type, rs[i].Type)
103105
c.Nullable = ls[i].Nullable || rs[i].Nullable
104106
}
105107
ret[i] = &c

sql/types/conversion.go

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -692,7 +692,6 @@ func GeneralizeTypes(a, b sql.Type) sql.Type {
692692
// TODO: match blob length to max of the blob lengths
693693
return LongBlob
694694
}
695-
696695
aIsBit := IsBit(a)
697696
bIsBit := IsBit(b)
698697
if aIsBit && bIsBit {
@@ -705,7 +704,6 @@ func GeneralizeTypes(a, b sql.Type) sql.Type {
705704
if bIsBit {
706705
b = Int64
707706
}
708-
709707
aIsYear := IsYear(a)
710708
bIsYear := IsYear(b)
711709
if aIsYear && bIsYear {

0 commit comments

Comments
 (0)