Skip to content

Commit 715b6fd

Browse files
authored
fix panic for group by binary type (#1844)
1 parent 3265db3 commit 715b6fd

File tree

5 files changed

+42
-6
lines changed

5 files changed

+42
-6
lines changed

enginetest/queries/order_by_group_by_queries.go

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,38 @@ var OrderByGroupByScriptTests = []ScriptTest{
8585
},
8686
},
8787
},
88+
{
89+
Name: "Group by BINARY: https://github.com/dolthub/dolt/issues/6179",
90+
SetUpScript: []string{
91+
"create table t (s varchar(100));",
92+
"insert into t values ('abc'), ('def');",
93+
"create table t1 (b binary(3));",
94+
"insert into t1 values ('abc'), ('abc'), ('def'), ('abc'), ('def');",
95+
},
96+
Assertions: []ScriptTestAssertion{
97+
{
98+
Query: "select binary s from t group by binary s order by binary s",
99+
Expected: []sql.Row{
100+
{[]uint8("abc")},
101+
{[]uint8("def")},
102+
},
103+
},
104+
{
105+
Query: "select count(b), b from t1 group by b order by b",
106+
Expected: []sql.Row{
107+
{3, []uint8("abc")},
108+
{2, []uint8("def")},
109+
},
110+
},
111+
{
112+
Query: "select binary s from t group by binary s order by s",
113+
Expected: []sql.Row{
114+
{[]uint8("abc")},
115+
{[]uint8("def")},
116+
},
117+
},
118+
},
119+
},
88120
{
89121
Name: "https://github.com/dolthub/dolt/issues/3016",
90122
SetUpScript: []string{

sql/analyzer/resolve_columns.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -344,7 +344,7 @@ func identifyGroupingAliasReferences(groupBy *plan.GroupBy) (*plan.GroupBy, tran
344344
return e, transform.SameTree, nil
345345
}
346346

347-
if stringContains(projectedAliases, strings.ToLower(uc.Name())) {
347+
if stringContains(projectedAliases, uc.Name()) {
348348
return expression.NewAliasReference(uc.Name()), transform.NewTree, nil
349349
}
350350

sql/analyzer/resolve_orderby.go

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -64,8 +64,7 @@ func pushdownSort(ctx *sql.Context, a *Analyzer, n sql.Node, scope *plan.Scope,
6464

6565
for _, n := range ns {
6666
col := tableColFromNameable(n)
67-
name := strings.ToLower(n.Name())
68-
if col.Table() == "" && stringContains(childAliases, name) {
67+
if col.Table() == "" && stringContains(childAliases, n.Name()) {
6968
colsFromChild = append(colsFromChild, n.Name())
7069
} else if !tableColsContains(schemaCols, col) {
7170
missingCols = append(missingCols, col)

sql/analyzer/validation_rules.go

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -329,7 +329,7 @@ func expressionReferencesOnlyGroupBys(groupBys []string, expr sql.Expression) bo
329329
// TODO: this isn't complete, it's overly restrictive. Dependant columns are fine to reference.
330330
default:
331331
if stringContains(groupBys, expr.String()) {
332-
return true
332+
return false
333333
}
334334

335335
if len(expr.Children()) == 0 {
@@ -659,8 +659,9 @@ func validateSubqueryColumns(ctx *sql.Context, a *Analyzer, n sql.Node, scope *p
659659
}
660660

661661
func stringContains(strs []string, target string) bool {
662+
lowerTarget := strings.ToLower(target)
662663
for _, s := range strs {
663-
if s == target {
664+
if lowerTarget == strings.ToLower(s) {
664665
return true
665666
}
666667
}

sql/rowexec/agg.go

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ import (
2323

2424
"github.com/dolthub/go-mysql-server/sql"
2525
"github.com/dolthub/go-mysql-server/sql/expression/function/aggregation"
26+
"github.com/dolthub/go-mysql-server/sql/types"
2627
)
2728

2829
type groupByIter struct {
@@ -248,7 +249,10 @@ func groupingKey(
248249

249250
t, isStringType := expr.Type().(sql.StringType)
250251
if isStringType && v != nil {
251-
err = t.Collation().WriteWeightString(hash, v.(string))
252+
v, err = types.ConvertToString(v, t)
253+
if err == nil {
254+
err = t.Collation().WriteWeightString(hash, v.(string))
255+
}
252256
} else {
253257
_, err = fmt.Fprintf(hash, "%v", v)
254258
}

0 commit comments

Comments
 (0)