Skip to content

Commit 5e66aff

Browse files
committed
workload: enable workload_generator to run SQL workloads with init artifacts
This patch wires up the `workload_generator` command to consume the outputs from `workload init`—the schema YAML and generated SQL files—and drive end‑to‑end data‑driven query execution: • Load and unmarshal the schema YAML produced during initialization • Discover `<db>_read.sql` and `<db>_write.sql` files emitted by the SQL generator • Apply the existing placeholderRewriter at runtime to inject generated data into each statement • Execute the rewritten SQL workload against a CockroachDB cluster, generating rows and collecting metrics Fixes: CRDB‑51752 Release note (cli change): Enables `workload_generator` to read init‑time schema and SQL artifacts and run SQL workloads with placeholder‑driven data generation.
1 parent 19fcdb6 commit 5e66aff

File tree

7 files changed

+666
-77
lines changed

7 files changed

+666
-77
lines changed

pkg/workload/workload_generator/BUILD.bazel

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ go_library(
55
srcs = [
66
"column_generator.go",
77
"constants.go",
8+
"run_utils.go",
89
"schema_designs.go",
910
"schema_generator.go",
1011
"schema_utils.go",
@@ -23,10 +24,13 @@ go_library(
2324
"//pkg/sql/sem/tree/treecmp",
2425
"//pkg/sql/types",
2526
"//pkg/util/bufalloc",
27+
"//pkg/util/syncutil",
2628
"//pkg/util/timeutil",
2729
"//pkg/workload",
2830
"//pkg/workload/histogram",
31+
"@com_github_cockroachdb_cockroach_go_v2//crdb",
2932
"@com_github_cockroachdb_errors//:errors",
33+
"@com_github_lib_pq//:pq",
3034
"@com_github_spf13_pflag//:pflag",
3135
"@in_gopkg_yaml_v2//:yaml_v2",
3236
],

pkg/workload/workload_generator/constants.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,4 +18,6 @@ const (
1818
sqlForeignKey = "FOREIGN KEY" // sqlForeignKey is a constant string used to represent foreign key constraints in column definitions
1919
sqlFamily = "FAMILY" // sqlFamily is a constant string used to represent family definitions in column definitions
2020
sqlConstraint = "CONSTRAINT" // sqlConstraint is a constant string used to represent constraint definitions in column definitions
21+
insert = "INSERT" // insert is a constant string used to represent INSERT statements in SQL
22+
update = "UPDATE" // update is a constant string used to represent UPDATE statements in SQL
2123
)
Lines changed: 278 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,278 @@
1+
// Copyright 2025 The Cockroach Authors.
2+
//
3+
// Use of this software is governed by the CockroachDB Software License
4+
// included in the /LICENSE file.
5+
6+
package workload_generator
7+
8+
import (
9+
"bufio"
10+
gosql "database/sql"
11+
"fmt"
12+
"io"
13+
"os"
14+
"regexp"
15+
"strconv"
16+
"strings"
17+
)
18+
19+
var (
20+
// txnRe defines how to separate the transactions from the <dbName.sql> file
21+
txnRe = regexp.MustCompile(`(?m)^-------Begin Transaction------\s*([\s\S]*?)-------End Transaction-------`)
22+
// placeholderRe defines the structure of the placeholders that need to be converted into $1, $2...
23+
placeholderRe = regexp.MustCompile(`"?:-:\|(.+?)\|:-:"?`)
24+
)
25+
26+
// setColumnValue sets the value for a placeholder in the args slice at index i.
27+
func setColumnValue(raw string, placeholder Placeholder, args []interface{}, i int) error {
28+
var arg interface{}
29+
// If we got an empty string and this column is nullable, emitting a SQL NULL.
30+
if raw == "" && placeholder.IsNullable {
31+
arg = setNullType(placeholder)
32+
} else {
33+
// Otherwise the raw string is parsed into the right Go/sql type.
34+
typedValue, err := setNotNullType(placeholder, raw)
35+
if err != nil {
36+
return err
37+
}
38+
arg = typedValue
39+
}
40+
41+
args[i] = arg
42+
return nil
43+
}
44+
45+
// setNotNullType converts the raw string value to the appropriate SQL type based on the placeholder's column type.
46+
func setNotNullType(placeholder Placeholder, raw string) (interface{}, error) {
47+
var arg interface{}
48+
switch t := strings.ToUpper(placeholder.ColType); {
49+
case strings.HasPrefix(t, "INT"): // integer types
50+
iv, err := strconv.ParseInt(raw, 10, 64)
51+
if err != nil {
52+
return nil, err
53+
}
54+
arg = gosql.NullInt64{Int64: iv, Valid: true}
55+
case strings.HasPrefix(t, "FLOAT"), strings.HasPrefix(t, "DECIMAL"), strings.HasPrefix(t, "NUMERIC"), strings.HasPrefix(t, "DOUBLE"): // floating point types
56+
fv, err := strconv.ParseFloat(raw, 64)
57+
if err != nil {
58+
return nil, err
59+
}
60+
arg = gosql.NullFloat64{Float64: fv, Valid: true}
61+
case t == "BOOL", t == "BOOLEAN": // boolean types
62+
bv, err := strconv.ParseBool(raw)
63+
if err != nil {
64+
return nil, err
65+
}
66+
arg = gosql.NullBool{Bool: bv, Valid: true}
67+
// The remaining types are parsed as raw strings.
68+
default:
69+
// Everything else is treated as text/varchar/etc.
70+
arg = gosql.NullString{String: raw, Valid: raw != ""}
71+
}
72+
return arg, nil
73+
}
74+
75+
// setNullType sets the argument to a SQL NULL value based on the column type of the placeholder.
76+
func setNullType(placeholder Placeholder) interface{} {
77+
var arg interface{}
78+
switch t := strings.ToUpper(placeholder.ColType); {
79+
case strings.HasPrefix(t, "INT"):
80+
arg = gosql.NullInt64{Valid: false}
81+
case strings.HasPrefix(t, "FLOAT"), strings.HasPrefix(t, "DECIMAL"), strings.HasPrefix(t, "NUMERIC"), strings.HasPrefix(t, "DOUBLE"):
82+
arg = gosql.NullFloat64{Valid: false}
83+
case t == "BOOL", t == "BOOLEAN":
84+
arg = gosql.NullBool{Valid: false}
85+
default:
86+
arg = gosql.NullString{Valid: false}
87+
}
88+
return arg
89+
}
90+
91+
// getTableName checks if the name field of placeholder is a column in the TableName column in the allSchema map inside d.
92+
// If yes, then returns that table name. otherwise looks for a table with that column and returns that.
93+
func getTableName(p Placeholder, d *workloadGenerator) string {
94+
for _, block := range d.workloadSchema[p.TableName] {
95+
for colName := range block.Columns {
96+
if colName == p.Name {
97+
return p.TableName
98+
}
99+
}
100+
}
101+
for tableName, blocks := range d.workloadSchema {
102+
block := blocks[0]
103+
for colName := range block.Columns {
104+
if colName == p.Name {
105+
return tableName
106+
}
107+
}
108+
}
109+
return p.TableName
110+
}
111+
112+
// getColumnValue retrieves the value for a placeholder based on its clause and whether it has a foreign key dependency.
113+
func getColumnValue(
114+
allPksAreFK bool,
115+
p Placeholder,
116+
d *workloadGenerator,
117+
inserted map[string][]interface{},
118+
indexes []int,
119+
i int,
120+
) string {
121+
var raw string
122+
if allPksAreFK && (p.Clause == insert || p.Clause == update) {
123+
tableName := getTableName(p, d)
124+
key := fmt.Sprintf("%s.%s", tableName, p.Name)
125+
fk := d.columnGens[key].columnMeta.FK
126+
parts := strings.Split(fk, ".")
127+
parentCol := parts[len(parts)-1] // The last part is the column name.
128+
if vals, ok := inserted[parentCol]; ok && len(vals) > 0 {
129+
raw = vals[0].(string) // Using the first value from the inserted map.
130+
inserted[parentCol] = vals[1:] // Remove the first value from the map.
131+
} else {
132+
//Fallback that shouldn't really happen.
133+
raw = d.getRegularColumnValue(p, indexes[i])
134+
}
135+
} else {
136+
raw = d.getRegularColumnValue(p, indexes[i])
137+
}
138+
return raw
139+
}
140+
141+
// checkIfAllPkAreFk checks if all primary keys in the SQL query are foreign keys.
142+
func checkIfAllPkAreFk(sqlQuery SQLQuery, d *workloadGenerator) bool {
143+
allPksAreFK := true
144+
for _, p := range sqlQuery.Placeholders {
145+
tableName := getTableName(p, d)
146+
key := fmt.Sprintf("%s.%s", tableName, p.Name)
147+
if p.IsPrimaryKey && !d.columnGens[key].columnMeta.HasForeignKey {
148+
allPksAreFK = false
149+
break
150+
}
151+
}
152+
return allPksAreFK
153+
}
154+
155+
// readSQL reads <dbName><read/write>.sql and returns a slice of Transactions.
156+
// It will number placeholders $1…$N separately in each SQL statement.
157+
func readSQL(path, typ string) ([]Transaction, error) {
158+
f, err := os.Open(path)
159+
if err != nil {
160+
return nil, fmt.Errorf("open file: %w", err)
161+
}
162+
defer func() {
163+
_ = f.Close()
164+
}()
165+
166+
data, err := io.ReadAll(bufio.NewReader(f))
167+
if err != nil {
168+
return nil, fmt.Errorf("read file: %w", err)
169+
}
170+
text := string(data)
171+
172+
// Each transaction block
173+
blocks := txnRe.FindAllStringSubmatch(text, -1)
174+
175+
// Currently we are defining two types of transactions - read and write.
176+
var txns []Transaction
177+
// For every transaction block.
178+
for _, blk := range blocks {
179+
body := blk[1]
180+
lines := strings.Split(body, "\n")
181+
var txn Transaction
182+
var curr SQLQuery
183+
// For every sql query line.
184+
for _, line := range lines {
185+
line = strings.TrimSpace(line)
186+
if line == "" || line == "BEGIN;" || line == "COMMIT;" {
187+
continue
188+
}
189+
// build up the SQL text
190+
curr.SQL += line + " "
191+
if strings.HasSuffix(line, ";") {
192+
// once we hit the end of a statement, re-number from $1
193+
stmtPos := 1
194+
var placeholders []Placeholder
195+
// sqlOut is the rewritten SQL where the placeholders have been replaced with $x.
196+
sqlOut := placeholderRe.ReplaceAllStringFunc(curr.SQL, getPlaceholderReplacer(&placeholders, &stmtPos))
197+
198+
curr.SQL = strings.TrimSpace(sqlOut)
199+
curr.Placeholders = placeholders
200+
txn.Queries = append(txn.Queries, curr)
201+
202+
// reset for next statement
203+
curr = SQLQuery{}
204+
}
205+
}
206+
// Decide whether the transaction is a read type or write type.
207+
txn.typ = typ
208+
txns = append(txns, txn)
209+
}
210+
211+
return txns, nil
212+
}
213+
214+
// getPlaceholderReplacer returns a ReplaceAllStringFunc that
215+
// appends to placeholders and increments stmtPos.
216+
func getPlaceholderReplacer(placeholders *[]Placeholder, stmtPos *int) func(string) string {
217+
return func(match string) string {
218+
inner := placeholderRe.FindStringSubmatch(match)[1]
219+
parts := splitQuoted(inner)
220+
221+
var p Placeholder
222+
//Set all the fields in the placeholder struct based on the information from the sql.
223+
p.Name = trimQuotes(parts[0])
224+
p.ColType = trimQuotes(parts[1])
225+
p.IsNullable = trimQuotes(parts[2]) == "NULL"
226+
p.IsPrimaryKey = strings.Contains(trimQuotes(parts[3]), "PRIMARY KEY")
227+
if d := trimQuotes(parts[4]); d != "" {
228+
p.Default = &d
229+
}
230+
p.IsUnique = trimQuotes(parts[5]) == "UNIQUE"
231+
if fk := trimQuotes(parts[6]); fk != "" {
232+
fkParts := strings.Split(strings.TrimPrefix(fk, "FK→"), ".")
233+
p.FKReference = &FKRef{Table: fkParts[0], Column: fkParts[1]}
234+
}
235+
if chk := trimQuotes(parts[7]); chk != "" {
236+
p.InlineCheck = &chk
237+
}
238+
p.Clause = trimQuotes(parts[8])
239+
p.Position = *stmtPos
240+
p.TableName = trimQuotes(parts[9])
241+
242+
*stmtPos++
243+
*placeholders = append(*placeholders, p)
244+
return fmt.Sprintf("$%d::%s", p.Position, p.ColType)
245+
}
246+
}
247+
248+
// splitQuoted splits a string like "'a','b','c'" into ["'a'", "'b'", "'c'"].
249+
func splitQuoted(s string) []string {
250+
var out []string
251+
buf := ""
252+
inQuote := false
253+
for _, r := range s {
254+
switch r {
255+
case '\'':
256+
inQuote = !inQuote
257+
buf += string(r)
258+
case ',':
259+
if inQuote {
260+
buf += string(r)
261+
} else {
262+
out = append(out, strings.TrimSpace(buf))
263+
buf = ""
264+
}
265+
default:
266+
buf += string(r)
267+
}
268+
}
269+
if buf != "" {
270+
out = append(out, strings.TrimSpace(buf))
271+
}
272+
return out
273+
}
274+
275+
// trimQuotes removes leading/trailing single-quotes.
276+
func trimQuotes(s string) string {
277+
return strings.Trim(s, "'")
278+
}

pkg/workload/workload_generator/sql_generator.go

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,13 @@ func generateWorkload(
120120
// 5d) The sql query is processed to replace _ and __more__ with new placeholders which contain information about the column they refer to.
121121
rewritten, err := replacePlaceholders(rawSQL, allSchemas)
122122
if err != nil {
123-
f.Close()
123+
if errClose := f.Close(); errClose != nil {
124+
// Wrap the original placeholder-rewrite error, then attach the close error.
125+
return errors.WithSecondaryError(
126+
errors.Wrapf(err, "rewriting SQL %q", rawSQL),
127+
errClose,
128+
)
129+
}
124130
return errors.Wrapf(err, "rewriting SQL %q", rawSQL)
125131
}
126132

pkg/workload/workload_generator/sql_utils.go

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -803,19 +803,25 @@ func writeTransaction(
803803
if err != nil {
804804
return errors.Wrapf(err, "creating %s", outPathRead)
805805
}
806-
defer outReadFile.Close()
806+
defer func() {
807+
_ = outReadFile.Close()
808+
}()
807809

808810
outWriteFile, err := os.Create(outPathWrite)
809811
if err != nil {
810812
return errors.Wrapf(err, "creating %s", outPathWrite)
811813
}
812-
defer outWriteFile.Close()
814+
defer func() {
815+
_ = outWriteFile.Close()
816+
}()
813817

814818
outUnhandledFile, err := os.Create(outPathUnhandled)
815819
if err != nil {
816820
return errors.Wrapf(err, "creating %s", outPathUnhandled)
817821
}
818-
defer outUnhandledFile.Close()
822+
defer func() {
823+
_ = outUnhandledFile.Close()
824+
}()
819825

820826
for _, txnID := range txnOrder {
821827
stmts := txnMap[txnID]

pkg/workload/workload_generator/types.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -214,7 +214,7 @@ func (u *UniqueWrapper) Next() string {
214214
// accept it
215215
u.seen[v] = struct{}{}
216216
u.order = append(u.order, v)
217-
// evict oldest if over capacity
217+
// evict oldest value if over capacity
218218
if len(u.order) > u.capacity {
219219
oldest := u.order[0]
220220
u.order = u.order[1:]

0 commit comments

Comments
 (0)