|
1 | | -// Copyright 2023 Dolthub, Inc. |
| 1 | +// Copyright 2023-2024 Dolthub, Inc. |
2 | 2 | // |
3 | 3 | // Licensed under the Apache License, Version 2.0 (the "License"); |
4 | 4 | // you may not use this file except in compliance with the License. |
|
15 | 15 | package sqle |
16 | 16 |
|
17 | 17 | import ( |
| 18 | + "context" |
18 | 19 | "testing" |
19 | 20 | "time" |
20 | 21 |
|
21 | 22 | "github.com/dolthub/vitess/go/vt/proto/query" |
22 | 23 | "github.com/stretchr/testify/require" |
23 | 24 |
|
| 25 | + "github.com/dolthub/go-mysql-server/memory" |
24 | 26 | "github.com/dolthub/go-mysql-server/sql" |
| 27 | + "github.com/dolthub/go-mysql-server/sql/analyzer" |
25 | 28 | "github.com/dolthub/go-mysql-server/sql/expression" |
| 29 | + "github.com/dolthub/go-mysql-server/sql/plan" |
| 30 | + "github.com/dolthub/go-mysql-server/sql/rowexec" |
26 | 31 | "github.com/dolthub/go-mysql-server/sql/types" |
| 32 | + "github.com/dolthub/go-mysql-server/sql/variables" |
27 | 33 | ) |
28 | 34 |
|
29 | 35 | func TestBindingsToExprs(t *testing.T) { |
@@ -145,3 +151,93 @@ func TestBindingsToExprs(t *testing.T) { |
145 | 151 | }) |
146 | 152 | } |
147 | 153 | } |
| 154 | + |
| 155 | +// wrapper around sql.Table to make it not indexable |
| 156 | +type nonIndexableTable struct { |
| 157 | + *memory.Table |
| 158 | +} |
| 159 | + |
| 160 | +var _ memory.MemTable = (*nonIndexableTable)(nil) |
| 161 | + |
| 162 | +func (t *nonIndexableTable) IgnoreSessionData() bool { |
| 163 | + return true |
| 164 | +} |
| 165 | + |
| 166 | +func getRuleFrom(rules []analyzer.Rule, id analyzer.RuleId) *analyzer.Rule { |
| 167 | + for _, rule := range rules { |
| 168 | + if rule.Id == id { |
| 169 | + return &rule |
| 170 | + } |
| 171 | + } |
| 172 | + |
| 173 | + return nil |
| 174 | +} |
| 175 | + |
| 176 | +// TODO: this was an analyzer test, but we don't have a mock process list for it to use, so it has to be here |
| 177 | +func TestTrackProcess(t *testing.T) { |
| 178 | + require := require.New(t) |
| 179 | + variables.InitStatusVariables() |
| 180 | + db := memory.NewDatabase("db") |
| 181 | + provider := memory.NewDBProvider(db) |
| 182 | + a := analyzer.NewDefault(provider) |
| 183 | + sess := memory.NewSession(sql.NewBaseSession(), provider) |
| 184 | + |
| 185 | + node := plan.NewInnerJoin( |
| 186 | + plan.NewResolvedTable(&nonIndexableTable{memory.NewPartitionedTable(db.BaseDatabase, "foo", sql.PrimaryKeySchema{}, nil, 2)}, nil, nil), |
| 187 | + plan.NewResolvedTable(memory.NewPartitionedTable(db.BaseDatabase, "bar", sql.PrimaryKeySchema{}, nil, 4), nil, nil), |
| 188 | + expression.NewLiteral(int64(1), types.Int64), |
| 189 | + ) |
| 190 | + |
| 191 | + pl := NewProcessList() |
| 192 | + |
| 193 | + ctx := sql.NewContext(context.Background(), sql.WithPid(1), sql.WithProcessList(pl), sql.WithSession(sess)) |
| 194 | + pl.AddConnection(ctx.Session.ID(), "localhost") |
| 195 | + pl.ConnectionReady(ctx.Session) |
| 196 | + ctx, err := ctx.ProcessList.BeginQuery(ctx, "SELECT foo") |
| 197 | + require.NoError(err) |
| 198 | + |
| 199 | + rule := getRuleFrom(analyzer.OnceAfterAll, analyzer.TrackProcessId) |
| 200 | + result, _, err := rule.Apply(ctx, a, node, nil, analyzer.DefaultRuleSelector, nil) |
| 201 | + require.NoError(err) |
| 202 | + |
| 203 | + processes := ctx.ProcessList.Processes() |
| 204 | + require.Len(processes, 1) |
| 205 | + require.Equal("SELECT foo", processes[0].Query) |
| 206 | + require.Equal( |
| 207 | + map[string]sql.TableProgress{ |
| 208 | + "foo": { |
| 209 | + Progress: sql.Progress{Name: "foo", Done: 0, Total: 2}, |
| 210 | + PartitionsProgress: map[string]sql.PartitionProgress{}, |
| 211 | + }, |
| 212 | + "bar": { |
| 213 | + Progress: sql.Progress{Name: "bar", Done: 0, Total: 4}, |
| 214 | + PartitionsProgress: map[string]sql.PartitionProgress{}, |
| 215 | + }, |
| 216 | + }, |
| 217 | + processes[0].Progress) |
| 218 | + |
| 219 | + join, ok := result.(*plan.JoinNode) |
| 220 | + require.True(ok) |
| 221 | + require.Equal(plan.JoinTypeInner, join.JoinType()) |
| 222 | + |
| 223 | + lhs, ok := join.Left().(*plan.ResolvedTable) |
| 224 | + require.True(ok) |
| 225 | + _, ok = lhs.Table.(*plan.ProcessTable) |
| 226 | + require.True(ok) |
| 227 | + |
| 228 | + rhs, ok := join.Right().(*plan.ResolvedTable) |
| 229 | + require.True(ok) |
| 230 | + _, ok = rhs.Table.(*plan.ProcessTable) |
| 231 | + require.True(ok) |
| 232 | + |
| 233 | + iter, err := rowexec.DefaultBuilder.Build(ctx, result, nil) |
| 234 | + iter = finalizeIters(ctx, result, nil, iter) |
| 235 | + require.NoError(err) |
| 236 | + _, err = sql.RowIterToRows(ctx, iter) |
| 237 | + require.NoError(err) |
| 238 | + |
| 239 | + processes = ctx.ProcessList.Processes() |
| 240 | + require.Len(processes, 1) |
| 241 | + require.Equal(sql.ProcessCommandSleep, processes[0].Command) |
| 242 | + require.Error(ctx.Err()) |
| 243 | +} |
0 commit comments