Skip to content
68 changes: 68 additions & 0 deletions enginetest/queries/script_queries.go
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,74 @@ type ScriptTestAssertion struct {
// Unlike other engine tests, ScriptTests must be self-contained. No other tables are created outside the definition of
// the tests.
var ScriptTests = []ScriptTest{
{
// Regression test for https://github.com/dolthub/dolt/issues/9641
Name: "bit union max1err dolt#9641",
SetUpScript: []string{
"CREATE TABLE report_card (id INT PRIMARY KEY, archived BIT(1))",
"INSERT INTO report_card VALUES (1, 0)",
},
Assertions: []ScriptTestAssertion{
{
// max1err
Query: `SELECT archived FROM report_card WHERE id = 1
UNION ALL
SELECT 48 FROM report_card WHERE id = 1`,
Expected: []sql.Row{{int64(0)}, {int64(48)}},
},
},
},
{
// Regression test for https://github.com/dolthub/dolt/issues/9641
Name: "bit union comprehensive regression test dolt#9641",
Dialect: "mysql",
SetUpScript: []string{
`CREATE TABLE t1 (
id int PRIMARY KEY,
name varchar(254),
archived bit(1) NOT NULL DEFAULT b'0',
archived_directly bit(1) NOT NULL DEFAULT b'0'
)`,
`CREATE TABLE t2 (
id int PRIMARY KEY,
name varchar(254),
archived bit(1) NOT NULL DEFAULT b'0'
)`,
"INSERT INTO t1 VALUES (1, 'Card1', b'0', b'0')",
"INSERT INTO t2 VALUES (2, 'Collection1', b'0')",
},
Assertions: []ScriptTestAssertion{
{
Query: `SELECT *,
COUNT(*) OVER () AS total_count
FROM (
SELECT
5 AS model_ranking,
id,
name,
archived,
archived_directly
FROM t1
WHERE archived_directly = FALSE

UNION ALL

SELECT
7 AS model_ranking,
id,
name,
archived,
NULL AS archived_directly
FROM t2
WHERE archived = FALSE AND id <> 1
) AS dummy_alias`,
Expected: []sql.Row{
{int64(5), int64(1), "Card1", uint64(0), int64(0), int64(2)},
{int64(7), int64(2), "Collection1", uint64(0), nil, int64(2)},
},
},
},
},
{
Name: "outer join finish unmatched right side",
SetUpScript: []string{
Expand Down
19 changes: 17 additions & 2 deletions enginetest/server_engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -378,6 +378,16 @@ func rowIterForGoSqlRows(ctx *sql.Context, sch sql.Schema, rows *gosql.Rows) (sq
func convertValue(ctx *sql.Context, sch sql.Schema, row sql.Row) sql.Row {
for i, col := range sch {
switch col.Type.Type() {
case query.Type_BIT:
if row[i] != nil {
if bytes, ok := row[i].([]byte); ok {
if bt, ok := col.Type.(types.BitType); ok {
if v, _, err := bt.Convert(ctx, bytes); err == nil {
row[i] = v
}
}
}
}
case query.Type_GEOMETRY:
if row[i] != nil {
r, _, err := types.GeometryType{}.Convert(ctx, row[i].([]byte))
Expand Down Expand Up @@ -562,9 +572,12 @@ func emptyRowForSchema(sch sql.Schema) ([]any, error) {
func emptyValuePointerForType(t sql.Type) (any, error) {
switch t.Type() {
case query.Type_INT8, query.Type_INT16, query.Type_INT24, query.Type_INT64,
query.Type_BIT, query.Type_YEAR:
query.Type_YEAR:
var i gosql.NullInt64
return &i, nil
case query.Type_BIT:
var b []byte
return &b, nil
case query.Type_INT32:
var i gosql.NullInt32
return &i, nil
Expand Down Expand Up @@ -624,8 +637,10 @@ func schemaForRows(rows *gosql.Rows) (sql.Schema, error) {

func convertGoSqlType(columnType *gosql.ColumnType) (sql.Type, error) {
switch strings.ToLower(columnType.DatabaseTypeName()) {
case "tinyint", "smallint", "mediumint", "int", "bigint", "bit":
case "tinyint", "smallint", "mediumint", "int", "bigint":
return types.Int64, nil
case "bit":
return types.MustCreateBitType(1), nil
case "unsigned tinyint", "unsigned smallint", "unsigned mediumint", "unsigned int", "unsigned bigint":
return types.Uint64, nil
case "float", "double":
Expand Down
4 changes: 4 additions & 0 deletions sql/analyzer/resolve_unions.go
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,10 @@ func finalizeUnions(ctx *sql.Context, a *Analyzer, n sql.Node, scope *plan.Scope
if err != nil {
return nil, transform.SameTree, err
}

// UNION can return multiple rows even when child queries use LIMIT 1, so disable Max1Row optimization
qFlags.Unset(sql.QFlagMax1Row)

return newN, transform.NewTree, nil
})
}
3 changes: 3 additions & 0 deletions sql/expression/convert.go
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,9 @@ func GetConvertToType(l, r sql.Type) string {
if types.IsDecimal(l) || types.IsDecimal(r) {
return ConvertToDecimal
}
if types.IsBit(l) || types.IsBit(r) {
return ConvertToSigned
}
if types.IsUnsigned(l) && types.IsUnsigned(r) {
return ConvertToUnsigned
}
Expand Down
2 changes: 2 additions & 0 deletions sql/plan/set_op.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import (
"strings"

"github.com/dolthub/go-mysql-server/sql"
"github.com/dolthub/go-mysql-server/sql/types"
)

const (
Expand Down Expand Up @@ -100,6 +101,7 @@ func (s *SetOp) Schema() sql.Schema {
for i := range ls {
c := *ls[i]
if i < len(rs) {
c.Type = types.GeneralizeTypes(ls[i].Type, rs[i].Type)
c.Nullable = ls[i].Nullable || rs[i].Nullable
}
ret[i] = &c
Expand Down
2 changes: 0 additions & 2 deletions sql/types/conversion.go
Original file line number Diff line number Diff line change
Expand Up @@ -692,7 +692,6 @@ func GeneralizeTypes(a, b sql.Type) sql.Type {
// TODO: match blob length to max of the blob lengths
return LongBlob
}

aIsBit := IsBit(a)
bIsBit := IsBit(b)
if aIsBit && bIsBit {
Expand All @@ -705,7 +704,6 @@ func GeneralizeTypes(a, b sql.Type) sql.Type {
if bIsBit {
b = Int64
}

aIsYear := IsYear(a)
bIsYear := IsYear(b)
if aIsYear && bIsYear {
Expand Down
Loading