Skip to content

Commit ef4499d

Browse files
authored
Merge pull request #3054 from dolthub/zachmu/project-iters
Support for set-returning functions
2 parents 152a18a + b45e7b8 commit ef4499d

File tree

11 files changed

+211
-12
lines changed

11 files changed

+211
-12
lines changed

sql/core.go

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,16 @@ type Expression interface {
4545
WithChildren(children ...Expression) (Expression, error)
4646
}
4747

48+
// RowIterExpression is an Expression that returns a RowIter rather than a scalar, used to implement functions that
49+
// return sets.
50+
type RowIterExpression interface {
51+
Expression
52+
// EvalRowIter evaluates the expression, which must be a RowIter
53+
EvalRowIter(ctx *Context, r Row) (RowIter, error)
54+
// ReturnsRowIter returns whether this expression returns a RowIter
55+
ReturnsRowIter() bool
56+
}
57+
4858
// ExpressionWithNodes is an expression that contains nodes as children.
4959
type ExpressionWithNodes interface {
5060
Expression

sql/expression/case.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,8 @@ func NewCase(expr sql.Expression, branches []CaseBranch, elseExpr sql.Expression
4545

4646
// Type implements the sql.Expression interface.
4747
func (c *Case) Type() sql.Type {
48-
curr := types.Null
48+
var curr sql.Type
49+
curr = types.Null
4950
for _, b := range c.Branches {
5051
curr = types.GeneralizeTypes(curr, b.Value.Type())
5152
}

sql/expression/function/coalesce.go

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,9 @@ func (c *Coalesce) Type() sql.Type {
5858
if c.typ != nil {
5959
return c.typ
6060
}
61-
retType := types.Null
61+
62+
var retType sql.Type
63+
retType = types.Null
6264
for i, arg := range c.args {
6365
if arg == nil {
6466
continue

sql/plan/project.go

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,15 @@ import (
2626
// Project is a projection of certain expression from the children node.
2727
type Project struct {
2828
UnaryNode
29+
// Projections are the expressions to be projected on the row returned by the child node
2930
Projections []sql.Expression
30-
CanDefer bool
31-
deps sql.ColSet
31+
// CanDefer is true when the projection evaluation can be deferred to row spooling, which allows us to avoid a
32+
// separate iterator for the project node.
33+
CanDefer bool
34+
// IncludesNestedIters is true when the projection includes nested iterators because of expressions that return
35+
// a RowIter.
36+
IncludesNestedIters bool
37+
deps sql.ColSet
3238
}
3339

3440
var _ sql.Expressioner = (*Project)(nil)
@@ -202,8 +208,17 @@ func (p *Project) WithExpressions(exprs ...sql.Expression) (sql.Node, error) {
202208
return &np, nil
203209
}
204210

211+
// WithCanDefer returns a new Project with the CanDefer field set to the given value.
205212
func (p *Project) WithCanDefer(canDefer bool) *Project {
206213
np := *p
207214
np.CanDefer = canDefer
208215
return &np
209216
}
217+
218+
// WithIncludesNestedIters returns a new Project with the IncludesNestedIters field set to the given value.
219+
func (p *Project) WithIncludesNestedIters(includesNestedIters bool) *Project {
220+
np := *p
221+
np.IncludesNestedIters = includesNestedIters
222+
np.CanDefer = false
223+
return &np
224+
}

sql/rowexec/rel.go

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -312,9 +312,10 @@ func (b *BaseBuilder) buildProject(ctx *sql.Context, n *plan.Project, row sql.Ro
312312
}
313313

314314
return sql.NewSpanIter(span, &ProjectIter{
315-
projs: n.Projections,
316-
canDefer: n.CanDefer,
317-
childIter: i,
315+
projs: n.Projections,
316+
canDefer: n.CanDefer,
317+
hasNestedIters: n.IncludesNestedIters,
318+
childIter: i,
318319
}), nil
319320
}
320321

sql/rowexec/rel_iters.go

Lines changed: 147 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ import (
2525
"github.com/dolthub/go-mysql-server/sql/hash"
2626
"github.com/dolthub/go-mysql-server/sql/iters"
2727
"github.com/dolthub/go-mysql-server/sql/plan"
28+
"github.com/dolthub/go-mysql-server/sql/transform"
2829
"github.com/dolthub/go-mysql-server/sql/types"
2930
)
3031

@@ -126,16 +127,29 @@ func (i *offsetIter) Close(ctx *sql.Context) error {
126127
var _ sql.RowIter = &iters.JsonTableRowIter{}
127128

128129
type ProjectIter struct {
129-
projs []sql.Expression
130-
canDefer bool
131-
childIter sql.RowIter
130+
projs []sql.Expression
131+
canDefer bool
132+
hasNestedIters bool
133+
nestedState *nestedIterState
134+
childIter sql.RowIter
135+
}
136+
137+
type nestedIterState struct {
138+
projections []sql.Expression
139+
sourceRow sql.Row
140+
iterEvaluators []*RowIterEvaluator
132141
}
133142

134143
func (i *ProjectIter) Next(ctx *sql.Context) (sql.Row, error) {
144+
if i.hasNestedIters {
145+
return i.ProjectRowWithNestedIters(ctx)
146+
}
147+
135148
childRow, err := i.childIter.Next(ctx)
136149
if err != nil {
137150
return nil, err
138151
}
152+
139153
return ProjectRow(ctx, i.projs, childRow)
140154
}
141155

@@ -155,6 +169,136 @@ func (i *ProjectIter) GetChildIter() sql.RowIter {
155169
return i.childIter
156170
}
157171

172+
// ProjectRowWithNestedIters evaluates a set of projections, allowing for nested iterators in the expressions.
173+
func (i *ProjectIter) ProjectRowWithNestedIters(
174+
ctx *sql.Context,
175+
) (sql.Row, error) {
176+
177+
// For the set of iterators, we return one row each element in the longest of the iterators provided.
178+
// Other iterator values will be NULL after they are depleted. All non-iterator fields for the row are returned
179+
// identically for each row in the result set.
180+
if i.nestedState != nil {
181+
row, err := ProjectRow(ctx, i.nestedState.projections, i.nestedState.sourceRow)
182+
if err != nil {
183+
return nil, err
184+
}
185+
186+
nestedIterationFinished := true
187+
for _, evaluator := range i.nestedState.iterEvaluators {
188+
if !evaluator.finished && evaluator.iter != nil {
189+
nestedIterationFinished = false
190+
break
191+
}
192+
}
193+
194+
if nestedIterationFinished {
195+
i.nestedState = nil
196+
return i.ProjectRowWithNestedIters(ctx)
197+
}
198+
199+
return row, nil
200+
}
201+
202+
row, err := i.childIter.Next(ctx)
203+
if err != nil {
204+
return nil, err
205+
}
206+
207+
i.nestedState = &nestedIterState{
208+
sourceRow: row,
209+
}
210+
211+
// We need a new set of projections, with any iterator-returning expressions replaced by new expressions that will
212+
// return the result of the iteration on each call to Eval. We also need to keep a list of all such iterators, so
213+
// that we can tell when they have all finished their iterations.
214+
var rowIterEvaluators []*RowIterEvaluator
215+
newProjs := make([]sql.Expression, len(i.projs))
216+
for i, proj := range i.projs {
217+
p, _, err := transform.Expr(proj, func(e sql.Expression) (sql.Expression, transform.TreeIdentity, error) {
218+
if rie, ok := e.(sql.RowIterExpression); ok && rie.ReturnsRowIter() {
219+
ri, err := rie.EvalRowIter(ctx, row)
220+
if err != nil {
221+
return nil, false, err
222+
}
223+
224+
evaluator := &RowIterEvaluator{
225+
iter: ri,
226+
typ: rie.Type(),
227+
}
228+
rowIterEvaluators = append(rowIterEvaluators, evaluator)
229+
return evaluator, transform.NewTree, nil
230+
}
231+
232+
return e, transform.SameTree, nil
233+
})
234+
235+
if err != nil {
236+
return nil, err
237+
}
238+
239+
newProjs[i] = p
240+
}
241+
242+
i.nestedState.projections = newProjs
243+
i.nestedState.iterEvaluators = rowIterEvaluators
244+
245+
return i.ProjectRowWithNestedIters(ctx)
246+
}
247+
248+
// RowIterEvaluator is an expression that returns the next value from a sql.RowIter each time Eval is called.
249+
type RowIterEvaluator struct {
250+
iter sql.RowIter
251+
typ sql.Type
252+
finished bool
253+
}
254+
255+
var _ sql.Expression = (*RowIterEvaluator)(nil)
256+
257+
func (r RowIterEvaluator) Resolved() bool {
258+
return true
259+
}
260+
261+
func (r RowIterEvaluator) String() string {
262+
return "RowIterEvaluator"
263+
}
264+
265+
func (r RowIterEvaluator) Type() sql.Type {
266+
return r.typ
267+
}
268+
269+
func (r RowIterEvaluator) IsNullable() bool {
270+
return true
271+
}
272+
273+
func (r *RowIterEvaluator) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
274+
if r.finished || r.iter == nil {
275+
return nil, nil
276+
}
277+
278+
nextRow, err := r.iter.Next(ctx)
279+
if err != nil {
280+
if errors.Is(err, io.EOF) {
281+
r.finished = true
282+
return nil, nil
283+
}
284+
return nil, err
285+
}
286+
287+
// All of the set-returning functions return a single value per column
288+
return nextRow[0], nil
289+
}
290+
291+
func (r RowIterEvaluator) Children() []sql.Expression {
292+
return nil
293+
}
294+
295+
func (r RowIterEvaluator) WithChildren(children ...sql.Expression) (sql.Expression, error) {
296+
if len(children) != 0 {
297+
return nil, sql.ErrInvalidChildrenNumber.New(r, len(children), 0)
298+
}
299+
return &r, nil
300+
}
301+
158302
// ProjectRow evaluates a set of projections.
159303
func ProjectRow(
160304
ctx *sql.Context,

sql/type.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,9 @@ type Type interface {
104104
// NullType represents the type of NULL values
105105
type NullType interface {
106106
Type
107+
108+
// IsNullType is a marker interface for types that represent NULL values.
109+
IsNullType() bool
107110
}
108111

109112
// DeferredType is a placeholder for prepared statements

sql/types/conversion.go

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -642,10 +642,10 @@ func GeneralizeTypes(a, b sql.Type) sql.Type {
642642
return a
643643
}
644644

645-
if a == Null {
645+
if IsNullType(a) {
646646
return b
647647
}
648-
if b == Null {
648+
if IsNullType(b) {
649649
return a
650650
}
651651

@@ -722,6 +722,16 @@ func GeneralizeTypes(a, b sql.Type) sql.Type {
722722
if IsNumber(a) && IsNumber(b) {
723723
return generalizeNumberTypes(a, b)
724724
}
725+
726+
if IsText(a) && IsText(b) {
727+
sta := a.(sql.StringType)
728+
stb := b.(sql.StringType)
729+
if sta.Length() > stb.Length() {
730+
return a
731+
}
732+
return b
733+
}
734+
725735
// TODO: decide if we want to make this VarChar to match MySQL, match VarChar length to max of two types
726736
return LongText
727737
}

sql/types/conversion_test.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -208,6 +208,7 @@ func TestGeneralizeTypes(t *testing.T) {
208208
{Date, Date, Date},
209209
{Date, Timestamp, DatetimeMaxPrecision},
210210
{Timestamp, Timestamp, Timestamp},
211+
{Timestamp, TimestampMaxPrecision, TimestampMaxPrecision},
211212
{Timestamp, Datetime, DatetimeMaxPrecision},
212213
{Null, Int64, Int64},
213214
{Null, Null, Null},

sql/types/null.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,10 @@ var (
3434

3535
type nullType struct{}
3636

37+
func (t nullType) IsNullType() bool {
38+
return true
39+
}
40+
3741
// Compare implements Type interface. Note that while this returns 0 (equals)
3842
// for ordering purposes, in SQL NULL != NULL.
3943
func (t nullType) Compare(s context.Context, a interface{}, b interface{}) (int, error) {

0 commit comments

Comments
 (0)