Skip to content

Commit 8666dcd

Browse files
author
James Cor
committed
reduce type asserts
1 parent e78b597 commit 8666dcd

File tree

4 files changed

+29
-40
lines changed

4 files changed

+29
-40
lines changed

sql/plan/filter.go

Lines changed: 14 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,6 @@
1515
package plan
1616

1717
import (
18-
"fmt"
19-
2018
"github.com/dolthub/go-mysql-server/sql"
2119
)
2220

@@ -106,6 +104,9 @@ func (f *Filter) Expressions() []sql.Expression {
106104
type FilterIter struct {
107105
cond sql.Expression
108106
childIter sql.RowIter
107+
108+
cond2 sql.Expression2
109+
childIter2 sql.RowIter2
109110
}
110111

111112
// NewFilterIter creates a new FilterIter.
@@ -136,23 +137,12 @@ func (i *FilterIter) Next(ctx *sql.Context) (sql.Row, error) {
136137
}
137138

138139
func (i *FilterIter) Next2(ctx *sql.Context) (sql.Row2, error) {
139-
ri2, ok := i.childIter.(sql.RowIter2)
140-
if !ok {
141-
panic(fmt.Sprintf("%T is not a sql.RowIter2", i.childIter))
142-
}
143-
144140
for {
145-
row, err := ri2.Next2(ctx)
141+
row, err := i.childIter2.Next2(ctx)
146142
if err != nil {
147143
return nil, err
148144
}
149-
150-
// TODO: write EvaluateCondition2?
151-
cond, isCond2 := i.cond.(sql.Expression2)
152-
if !isCond2 {
153-
panic(fmt.Sprintf("%T does not implement sql.Expression2 interface", i.cond))
154-
}
155-
res, err := cond.Eval2(ctx, row)
145+
res, err := i.cond2.Eval2(ctx, row)
156146
if err != nil {
157147
return nil, err
158148
}
@@ -163,15 +153,17 @@ func (i *FilterIter) Next2(ctx *sql.Context) (sql.Row2, error) {
163153
}
164154

165155
func (i *FilterIter) IsRowIter2(ctx *sql.Context) bool {
166-
if cond, isExpr2 := i.cond.(sql.Expression2); isExpr2 {
167-
if !cond.IsExpr2() {
168-
return false
169-
}
156+
cond, ok := i.cond.(sql.Expression2)
157+
if !ok || !cond.IsExpr2() {
158+
return false
170159
}
171-
if ri2, ok := i.childIter.(sql.RowIter2); ok {
172-
return ri2.IsRowIter2(ctx)
160+
childIter, ok := i.childIter.(sql.RowIter2)
161+
if !ok || !childIter.IsRowIter2(ctx) {
162+
return false
173163
}
174-
return false
164+
i.cond2 = cond
165+
i.childIter2 = childIter
166+
return true
175167
}
176168

177169
// Close implements the RowIter interface.

sql/plan/process.go

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -226,6 +226,7 @@ const (
226226
type TrackedRowIter struct {
227227
node sql.Node
228228
iter sql.RowIter
229+
iter2 sql.RowIter2
229230
onDone NotifyFunc
230231
onNext NotifyFunc
231232
numRows int64
@@ -318,11 +319,7 @@ func (i *TrackedRowIter) Next(ctx *sql.Context) (sql.Row, error) {
318319
}
319320

320321
func (i *TrackedRowIter) Next2(ctx *sql.Context) (sql.Row2, error) {
321-
ri2, ok := i.iter.(sql.RowIter2)
322-
if !ok {
323-
panic(fmt.Sprintf("%T does not implement sql.RowIter2 interface", i.iter))
324-
}
325-
row, err := ri2.Next2(ctx)
322+
row, err := i.iter2.Next2(ctx)
326323
if err != nil {
327324
return nil, err
328325
}
@@ -334,10 +331,12 @@ func (i *TrackedRowIter) Next2(ctx *sql.Context) (sql.Row2, error) {
334331
}
335332

336333
func (i *TrackedRowIter) IsRowIter2(ctx *sql.Context) bool {
337-
if ri2, ok := i.iter.(sql.RowIter2); ok {
338-
return ri2.IsRowIter2(ctx)
334+
iter, ok := i.iter.(sql.RowIter2)
335+
if !ok || !iter.IsRowIter2(ctx) {
336+
return false
339337
}
340-
return false
338+
i.iter2 = iter
339+
return true
341340
}
342341

343342
func (i *TrackedRowIter) Close(ctx *sql.Context) error {

sql/rowexec/transaction_iters.go

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
package rowexec
1616

1717
import (
18-
"fmt"
1918
"io"
2019

2120
"gopkg.in/src-d/go-errors.v1"
@@ -72,6 +71,7 @@ func getLockableTable(table sql.Table) (sql.Lockable, error) {
7271
// during the Close() operation
7372
type TransactionCommittingIter struct {
7473
childIter sql.RowIter
74+
childIter2 sql.RowIter2
7575
transactionDatabase string
7676
autoCommit bool
7777
implicitCommit bool
@@ -101,18 +101,16 @@ func (t *TransactionCommittingIter) Next(ctx *sql.Context) (sql.Row, error) {
101101
}
102102

103103
func (t *TransactionCommittingIter) Next2(ctx *sql.Context) (sql.Row2, error) {
104-
ri2, ok := t.childIter.(sql.RowIter2)
105-
if !ok {
106-
panic(fmt.Sprintf("%T does not implement sql.RowIter2 interface", t.childIter))
107-
}
108-
return ri2.Next2(ctx)
104+
return t.childIter2.Next2(ctx)
109105
}
110106

111107
func (t *TransactionCommittingIter) IsRowIter2(ctx *sql.Context) bool {
112-
if ri2, ok := t.childIter.(sql.RowIter2); ok {
113-
return ri2.IsRowIter2(ctx)
108+
childIter, ok := t.childIter.(sql.RowIter2)
109+
if !ok || !childIter.IsRowIter2(ctx) {
110+
return false
114111
}
115-
return false
112+
t.childIter2 = childIter
113+
return true
116114
}
117115

118116
func (t *TransactionCommittingIter) Close(ctx *sql.Context) error {

sql/table_iter.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,7 @@ func (i *TableRowIter) Next2(ctx *Context) (Row2, error) {
106106
return nil, err
107107
}
108108
ri2, ok := rows.(RowIter2)
109-
if !ok {
109+
if !ok || !ri2.IsRowIter2(ctx) {
110110
panic(fmt.Sprintf("%T does not implement RowIter2", rows))
111111
}
112112
i.rows2 = ri2

0 commit comments

Comments
 (0)