Skip to content

Commit 1af0f6e

Browse files
authored
Merge pull request #3063 from dolthub/angela/generated_columns
Fix `count(*)` for added generated columns
2 parents abc950c + be6b2de commit 1af0f6e

File tree

4 files changed

+122
-17
lines changed

4 files changed

+122
-17
lines changed

enginetest/queries/generated_columns.go

Lines changed: 116 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,10 @@ var GeneratedColumnTests = []ScriptTest{
9393
Query: "select * from t1 order by a",
9494
Expected: []sql.Row{{2, 3}, {3, 4}, {4, 5}, {5, 6}},
9595
},
96+
{
97+
Query: "select count(*) from t1",
98+
Expected: []sql.Row{{4}},
99+
},
96100
},
97101
},
98102
{
@@ -171,9 +175,15 @@ var GeneratedColumnTests = []ScriptTest{
171175
"INSERT INTO t16 (pk) VALUES (1), (2)",
172176
"ALTER TABLE t16 ADD COLUMN v2 BIGINT AS (5) STORED FIRST",
173177
},
174-
Assertions: []ScriptTestAssertion{{
175-
Query: "SELECT * FROM t16",
176-
Expected: []sql.Row{{5, 1, 4}, {5, 2, 4}}},
178+
Assertions: []ScriptTestAssertion{
179+
{
180+
Query: "SELECT * FROM t16",
181+
Expected: []sql.Row{{5, 1, 4}, {5, 2, 4}},
182+
},
183+
{
184+
Query: "select count(*) from t16",
185+
Expected: []sql.Row{{2}},
186+
},
177187
},
178188
},
179189
{
@@ -183,9 +193,15 @@ var GeneratedColumnTests = []ScriptTest{
183193
"INSERT INTO t17 VALUES (1, 3), (2, 4)",
184194
"ALTER TABLE t17 ADD COLUMN v2 BIGINT AS (v1 + 2) STORED FIRST",
185195
},
186-
Assertions: []ScriptTestAssertion{{
187-
Query: "SELECT * FROM t17",
188-
Expected: []sql.Row{{5, 1, 3}, {6, 2, 4}}},
196+
Assertions: []ScriptTestAssertion{
197+
{
198+
Query: "SELECT * FROM t17",
199+
Expected: []sql.Row{{5, 1, 3}, {6, 2, 4}},
200+
},
201+
{
202+
Query: "select count(*) from t17",
203+
Expected: []sql.Row{{2}},
204+
},
189205
},
190206
},
191207
{
@@ -275,6 +291,10 @@ var GeneratedColumnTests = []ScriptTest{
275291
Query: "select * from t1 order by b",
276292
Expected: []sql.Row{{1, 2}, {2, 3}},
277293
},
294+
{
295+
Query: "select count(*) from t1",
296+
Expected: []sql.Row{{2}},
297+
},
278298
},
279299
},
280300
{
@@ -347,6 +367,10 @@ var GeneratedColumnTests = []ScriptTest{
347367
Query: "select * from t1 order by b",
348368
Expected: []sql.Row{{1, 2, 3, 4}, {2, 3, 4, 5}},
349369
},
370+
{
371+
Query: "select count(*) from t1",
372+
Expected: []sql.Row{{2}},
373+
},
350374
},
351375
},
352376
{
@@ -427,6 +451,10 @@ var GeneratedColumnTests = []ScriptTest{
427451
Query: "select * from t1 order by b",
428452
Expected: []sql.Row{{1, 2}, {2, 3}},
429453
},
454+
{
455+
Query: "select count(*) from t1",
456+
Expected: []sql.Row{{2}},
457+
},
430458
},
431459
},
432460
{
@@ -580,6 +608,10 @@ var GeneratedColumnTests = []ScriptTest{
580608
" PRIMARY KEY (`a`)\n" +
581609
") ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_bin"}},
582610
},
611+
{
612+
Query: "select count(*) from t1",
613+
Expected: []sql.Row{{3}},
614+
},
583615
},
584616
},
585617
{
@@ -616,6 +648,10 @@ var GeneratedColumnTests = []ScriptTest{
616648
{1, 2, 4},
617649
},
618650
},
651+
{
652+
Query: "select count(*) from t",
653+
Expected: []sql.Row{{1}},
654+
},
619655
{
620656
Query: "alter table tt add column `col 3` int generated always as (`col 1` + `col 2` + pow(`col 1`, `col 2`)) stored;",
621657
Expected: []sql.Row{
@@ -644,6 +680,10 @@ var GeneratedColumnTests = []ScriptTest{
644680
{1, 2, 4},
645681
},
646682
},
683+
{
684+
Query: "select count(*) from tt",
685+
Expected: []sql.Row{{1}},
686+
},
647687
},
648688
},
649689
{
@@ -680,6 +720,10 @@ var GeneratedColumnTests = []ScriptTest{
680720
{1, 2, 4},
681721
},
682722
},
723+
{
724+
Query: "select count(*) from t",
725+
Expected: []sql.Row{{1}},
726+
},
683727
{
684728
Query: "alter table tt add column `col 3` int generated always as (`col 1` + `col 2` + pow(`col 1`, `col 2`)) virtual;",
685729
Expected: []sql.Row{
@@ -708,6 +752,10 @@ var GeneratedColumnTests = []ScriptTest{
708752
{1, 2, 4},
709753
},
710754
},
755+
{
756+
Query: "select count(*) from tt",
757+
Expected: []sql.Row{{1}},
758+
},
711759
},
712760
},
713761
{
@@ -717,9 +765,15 @@ var GeneratedColumnTests = []ScriptTest{
717765
"INSERT INTO t16 (pk) VALUES (1), (2)",
718766
"ALTER TABLE t16 ADD COLUMN v2 BIGINT AS (5) VIRTUAL FIRST",
719767
},
720-
Assertions: []ScriptTestAssertion{{
721-
Query: "SELECT * FROM t16",
722-
Expected: []sql.Row{{5, 1, 4}, {5, 2, 4}}},
768+
Assertions: []ScriptTestAssertion{
769+
{
770+
Query: "SELECT * FROM t16",
771+
Expected: []sql.Row{{5, 1, 4}, {5, 2, 4}},
772+
},
773+
{
774+
Query: "select count(*) from t16",
775+
Expected: []sql.Row{{2}},
776+
},
723777
},
724778
},
725779
{
@@ -729,9 +783,15 @@ var GeneratedColumnTests = []ScriptTest{
729783
"INSERT INTO t17 VALUES (1, 3), (2, 4)",
730784
"ALTER TABLE t17 ADD COLUMN v2 BIGINT AS (v1 + 2) VIRTUAL FIRST",
731785
},
732-
Assertions: []ScriptTestAssertion{{
733-
Query: "SELECT * FROM t17",
734-
Expected: []sql.Row{{5, 1, 3}, {6, 2, 4}}},
786+
Assertions: []ScriptTestAssertion{
787+
{
788+
Query: "SELECT * FROM t17",
789+
Expected: []sql.Row{{5, 1, 3}, {6, 2, 4}},
790+
},
791+
{
792+
Query: "SELECT count(*) FROM t17",
793+
Expected: []sql.Row{{2}},
794+
},
735795
},
736796
},
737797
{
@@ -872,6 +932,14 @@ var GeneratedColumnTests = []ScriptTest{
872932
Query: "select * from t2 order by c",
873933
Expected: []sql.Row{{1, 0}, {2, 1}, {3, 2}, {6, 5}, {7, 6}},
874934
},
935+
{
936+
Query: "select count(*) from t1",
937+
Expected: []sql.Row{{5}},
938+
},
939+
{
940+
Query: "select count(*) from t2",
941+
Expected: []sql.Row{{5}},
942+
},
875943
},
876944
},
877945
{
@@ -891,6 +959,10 @@ var GeneratedColumnTests = []ScriptTest{
891959
{2, types.MustJSON(`{"a": 1}`), nil},
892960
{3, types.MustJSON(`{"b": "300"}`), 300}},
893961
},
962+
{
963+
Query: "select count(*) from t1",
964+
Expected: []sql.Row{{3}},
965+
},
894966
},
895967
},
896968
{
@@ -911,6 +983,10 @@ var GeneratedColumnTests = []ScriptTest{
911983
{"ghi", "", "ghi"},
912984
},
913985
},
986+
{
987+
Query: "select count(*) from t1",
988+
Expected: []sql.Row{{3}},
989+
},
914990
},
915991
},
916992
{
@@ -951,6 +1027,10 @@ var GeneratedColumnTests = []ScriptTest{
9511027
{2, 3, 4, 5},
9521028
},
9531029
},
1030+
{
1031+
Query: "select count(*) from t",
1032+
Expected: []sql.Row{{3}},
1033+
},
9541034
},
9551035
},
9561036
{
@@ -1028,6 +1108,10 @@ var GeneratedColumnTests = []ScriptTest{
10281108
Query: "select * from t1 order by a",
10291109
Expected: []sql.Row{{1, 2, 3}, {3, 4, 7}},
10301110
},
1111+
{
1112+
Query: "select count(*) from t1",
1113+
Expected: []sql.Row{{2}},
1114+
},
10311115
},
10321116
},
10331117
{
@@ -1092,6 +1176,10 @@ var GeneratedColumnTests = []ScriptTest{
10921176
{3, 4, 7},
10931177
},
10941178
},
1179+
{
1180+
Query: "select count(*) from t1",
1181+
Expected: []sql.Row{{2}},
1182+
},
10951183
{
10961184
Query: "select * from t1 where c = 6",
10971185
Expected: []sql.Row{
@@ -1121,6 +1209,10 @@ var GeneratedColumnTests = []ScriptTest{
11211209
Query: "select * from t1 where v = 2",
11221210
Expected: []sql.Row{{"{\"a\": 2}", 2}},
11231211
},
1212+
{
1213+
Query: "select count(*) from t1",
1214+
Expected: []sql.Row{{3}},
1215+
},
11241216
{
11251217
Query: "update t1 set j = '{\"a\": 5}' where v = 2",
11261218
Expected: []sql.Row{{NewUpdateResult(1, 1)}},
@@ -1217,6 +1309,10 @@ var GeneratedColumnTests = []ScriptTest{
12171309
Query: "select * from t1 order by b",
12181310
Expected: []sql.Row{{1, 2, 3, 4}, {2, 3, 4, 5}},
12191311
},
1312+
{
1313+
Query: "select count(*) from t1",
1314+
Expected: []sql.Row{{2}},
1315+
},
12201316
},
12211317
},
12221318
{
@@ -1301,6 +1397,10 @@ var GeneratedColumnTests = []ScriptTest{
13011397
Query: "insert into t2 (a) values (1), (2)",
13021398
Expected: []sql.Row{{types.NewOkResult(2)}},
13031399
},
1400+
{
1401+
Query: "select count(*) from t2",
1402+
Expected: []sql.Row{{2}},
1403+
},
13041404
{
13051405
Query: "select * from t2 order by a",
13061406
Expected: []sql.Row{
@@ -1318,6 +1418,10 @@ var GeneratedColumnTests = []ScriptTest{
13181418
Query: "insert into t3 (a) values (1), (2)",
13191419
Expected: []sql.Row{{types.NewOkResult(2)}},
13201420
},
1421+
{
1422+
Query: "select count(*) from t3",
1423+
Expected: []sql.Row{{2}},
1424+
},
13211425
{
13221426
Query: "select * from t3 order by a",
13231427
Expected: []sql.Row{

sql/analyzer/catalog.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -467,6 +467,8 @@ func getStatisticsTable(table sql.Table, prevTable sql.Table) (sql.StatisticsTab
467467
return t, true
468468
case sql.TableNode:
469469
return getStatisticsTable(t.UnderlyingTable(), table)
470+
case sql.TableWrapper:
471+
return getStatisticsTable(t.Underlying(), table)
470472
default:
471473
return nil, false
472474
}

sql/analyzer/replace_count_star.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ func replaceCountStar(ctx *sql.Context, a *Analyzer, n sql.Node, _ *plan.Scope,
9595
return n, transform.SameTree, nil
9696
}
9797

98-
if statsTable, ok := rt.Table.(sql.StatisticsTable); ok {
98+
if statsTable, ok := getStatisticsTable(rt.Table, nil); ok {
9999
rowCnt, exact, err := statsTable.RowCount(ctx)
100100
if err == nil && exact {
101101
return plan.NewProject(

sql/analyzer/symbol_resolution.go

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -208,16 +208,15 @@ func pruneTableCols(
208208
}
209209

210210
// Don't prune columns if they're needed by a virtual column
211-
virtualColDeps := make(map[tableCol]int)
211+
virtualColDeps := make(map[string]int)
212212
if !selectStar { // if selectStar, we're adding all columns anyway
213213
if vct, isVCT := n.WrappedTable().(*plan.VirtualColumnTable); isVCT {
214214
for _, projection := range vct.Projections {
215215
transform.InspectExpr(projection, func(e sql.Expression) bool {
216216
if cd, isCD := e.(*sql.ColumnDefaultValue); isCD {
217217
transform.InspectExpr(cd.Expr, func(e sql.Expression) bool {
218218
if gf, ok := e.(*expression.GetField); ok {
219-
c := newTableCol(gf.Table(), gf.Name())
220-
virtualColDeps[c]++
219+
virtualColDeps[gf.Name()]++
221220
}
222221
return false
223222
})
@@ -232,7 +231,7 @@ func pruneTableCols(
232231
source := strings.ToLower(table.Name())
233232
for _, col := range table.Schema() {
234233
c := newTableCol(source, col.Name)
235-
if selectStar || parentCols[c] > 0 || virtualColDeps[c] > 0 {
234+
if selectStar || parentCols[c] > 0 || virtualColDeps[c.Name()] > 0 {
236235
cols = append(cols, c.col)
237236
}
238237
}

0 commit comments

Comments
 (0)