Skip to content

Commit 49ebad7

Browse files
committed
fix(planner): avoid schema null panics in slice mode
Signed-off-by: Jiyong Huang <huangjy@emqx.io>
1 parent b8272ae commit 49ebad7

File tree

3 files changed

+83
-14
lines changed

3 files changed

+83
-14
lines changed

internal/topo/operator/project_operator.go

Lines changed: 35 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,8 @@ type ProjectOp struct {
5555
// compiledExprs caches pre-compiled accessors for ExprFields.
5656
// For simple FieldRef fields, isDirect=true and we skip the AST walk.
5757
compiledExprs []compiledField
58+
aliasIndices []int
59+
fieldIndices []int
5860
}
5961

6062
// compiledField caches whether an ExprField is a simple direct lookup
@@ -178,6 +180,14 @@ func (pp *ProjectOp) getVE(tuple xsql.RawRow, agg xsql.AggregateData, wr *xsql.W
178180
}
179181
pp.compiledExprs = append(pp.compiledExprs, cf)
180182
}
183+
pp.aliasIndices = make([]int, len(pp.AliasFields))
184+
for i, f := range pp.AliasFields {
185+
pp.aliasIndices[i] = getExprIndex(f.Expr)
186+
}
187+
pp.fieldIndices = make([]int, len(pp.Fields))
188+
for i, f := range pp.Fields {
189+
pp.fieldIndices[i] = getExprIndex(f.Expr)
190+
}
181191
}
182192

183193
pp.wv.Data = tuple
@@ -217,33 +227,32 @@ func (pp *ProjectOp) getRowVE(tuple xsql.Row, wr *xsql.WindowRange, fv *xsql.Fun
217227
func (pp *ProjectOp) project(row xsql.RawRow, ve *xsql.ValuerEval) error {
218228
switch rt := row.(type) {
219229
case *xsql.SliceTuple:
220-
for _, f := range pp.AliasFields {
230+
for i, f := range pp.AliasFields {
221231
vi := ve.Eval(f.Expr)
222232
if e, ok := vi.(error); ok {
223233
return fmt.Errorf("expr: %s meet error, err:%v", f.Expr.String(), e)
224234
}
225235
if pp.SendNil && vi == nil {
226-
// set it to a typed nil to distinguish from nil
227-
// so that the encoder can treat it differently from nil
228236
vi = cast.TNil
229237
}
230-
fr := f.Expr.(*ast.FieldRef)
231-
rt.SetByIndex(fr.Index, vi)
238+
index := pp.aliasIndices[i]
239+
if index >= 0 {
240+
rt.SetByIndex(index, vi)
241+
}
232242
}
233-
for _, f := range pp.Fields {
243+
for i, f := range pp.Fields {
234244
if f.AName == "" {
235245
vi := ve.Eval(f.Expr)
236246
if e, ok := vi.(error); ok {
237247
return fmt.Errorf("expr: %s meet error, err:%v", f.Expr.String(), e)
238248
}
239-
// TODO deal with other types
240249
if pp.SendNil && vi == nil {
241-
// set it to a typed nil to distinguish from nil
242-
// so that the encoder can treat it differently from nil
243250
vi = cast.TNil
244251
}
245-
fr := f.Expr.(*ast.FieldRef)
246-
rt.SetByIndex(fr.Index, vi)
252+
index := pp.fieldIndices[i]
253+
if index >= 0 {
254+
rt.SetByIndex(index, vi)
255+
}
247256
}
248257
}
249258
rt.Compact(pp.FieldLen)
@@ -305,3 +314,18 @@ func (pp *ProjectOp) project(row xsql.RawRow, ve *xsql.ValuerEval) error {
305314
}
306315
return nil
307316
}
317+
318+
func getExprIndex(expr ast.Expr) int {
319+
if fr, ok := expr.(*ast.FieldRef); ok {
320+
return fr.Index
321+
}
322+
var targetIdx int = -1
323+
ast.WalkFunc(expr, func(n ast.Node) bool {
324+
if fr, ok := n.(*ast.FieldRef); ok {
325+
targetIdx = fr.Index
326+
return false // stop walking
327+
}
328+
return true
329+
})
330+
return targetIdx
331+
}

internal/topo/operator/project_test.go

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3358,3 +3358,36 @@ func TestSliceNilField(t *testing.T) {
33583358
})
33593359
}
33603360
}
3361+
3362+
3363+
func TestProjectOp_SliceTuple_ComplexExpression(t *testing.T) {
3364+
sql := "SELECT abs(a) FROM test"
3365+
stmt, err := xsql.NewParser(strings.NewReader(sql)).Parse()
3366+
require.NoError(t, err)
3367+
3368+
pp := &ProjectOp{}
3369+
parseStmtWithSlice(pp, stmt.Fields, true)
3370+
3371+
// In current parseStmtWithSlice, inner ast.FieldRef elements of ast.Call
3372+
// don't get HasIndex=true automatically. Let's fix that for this test.
3373+
ast.WalkFunc(stmt.Fields[0].Expr, func(n ast.Node) bool {
3374+
if fr, ok := n.(*ast.FieldRef); ok {
3375+
fr.HasIndex = true
3376+
fr.SourceIndex = constSourceIndex[fr.Name]
3377+
}
3378+
return true
3379+
})
3380+
3381+
fv, afv := xsql.NewFunctionValuersForOp(nil)
3382+
3383+
sliceTuple := &xsql.SliceTuple{
3384+
SourceContent: make([]interface{}, 5),
3385+
}
3386+
sliceTuple.SourceContent[constSourceIndex["a"]] = int64(-42)
3387+
3388+
ve := pp.getRowVE(sliceTuple, nil, fv, afv)
3389+
err = pp.project(sliceTuple, ve)
3390+
require.NoError(t, err)
3391+
3392+
require.Equal(t, int64(42), sliceTuple.SourceContent[0])
3393+
}

internal/topo/planner/planner.go

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,12 @@ func updateFieldIndex(ctx api.StreamContext, stmt *ast.SelectStatement, af []*as
108108
case *ast.FieldRef:
109109
if nf.IsColumn() {
110110
sc := schema.GetStreamSchemaIndex(string(nf.StreamName))
111+
if sc == nil {
112+
streamsFromStmt := xsql.GetStreams(stmt)
113+
if len(streamsFromStmt) == 1 {
114+
sc = schema.GetStreamSchemaIndex(streamsFromStmt[0])
115+
}
116+
}
111117
if sc != nil {
112118
if si, ok := sc[nf.Name]; ok {
113119
nf.SourceIndex = si
@@ -148,10 +154,10 @@ func updateFieldIndex(ctx api.StreamContext, stmt *ast.SelectStatement, af []*as
148154
}
149155
index := len(stmt.Fields)
150156
for _, fieldExpr := range fieldExprs {
151-
index = doUpdateIndex(ctx, fieldExpr, index, aliasIndex)
157+
index = doUpdateIndex(ctx, stmt, fieldExpr, index, aliasIndex)
152158
}
153159
// Add sink index for other non-select fields
154-
index = doUpdateIndex(ctx, stmt, index, aliasIndex)
160+
index = doUpdateIndex(ctx, stmt, stmt, index, aliasIndex)
155161
ctx.GetLogger().Infof("assign %d field index", index)
156162
// Set temp index for analytic funcs
157163
index = 0
@@ -166,13 +172,19 @@ func updateFieldIndex(ctx api.StreamContext, stmt *ast.SelectStatement, af []*as
166172
ctx.GetLogger().Infof("assign %d temp index", index)
167173
}
168174

169-
func doUpdateIndex(ctx api.StreamContext, root ast.Node, index int, aliasIndex map[string]int) int {
175+
func doUpdateIndex(ctx api.StreamContext, stmt *ast.SelectStatement, root ast.Node, index int, aliasIndex map[string]int) int {
170176
ast.WalkFunc(root, func(n ast.Node) bool {
171177
switch nf := n.(type) {
172178
case *ast.FieldRef:
173179
nf.HasIndex = true
174180
if nf.IsColumn() {
175181
sc := schema.GetStreamSchemaIndex(string(nf.StreamName))
182+
if sc == nil {
183+
streamsFromStmt := xsql.GetStreams(stmt)
184+
if len(streamsFromStmt) == 1 {
185+
sc = schema.GetStreamSchemaIndex(streamsFromStmt[0])
186+
}
187+
}
176188
if sc != nil {
177189
if si, ok := sc[nf.Name]; ok {
178190
nf.SourceIndex = si

0 commit comments

Comments
 (0)