2 changes: 1 addition & 1 deletion cmd/round/main.go
@@ -184,7 +184,7 @@ func processRoundDecimalFile(inputPath string) (err error) {
if err != nil {
return err
}
- bufWriter = bufio.NewWriter(outputFile)
+ bufWriter = bufio.NewWriterSize(outputFile, 8192)
colCount = len(cols)
} else {
// no decimal column, quick exit
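For context on this one-line change: bufio.NewWriter uses a default 4096-byte buffer, so NewWriterSize(outputFile, 8192) doubles the buffer and reduces the number of write syscalls for large outputs. A minimal sketch of the API difference (not from the PR):

```go
package main

import (
	"bufio"
	"fmt"
	"os"
)

func main() {
	// bufio.NewWriter(os.Stdout) would use the default 4096-byte buffer;
	// NewWriterSize makes the size explicit, 8192 as in the PR.
	w := bufio.NewWriterSize(os.Stdout, 8192)
	defer w.Flush() // buffered bytes are lost if Flush is never called
	fmt.Fprintln(w, "buffered write")
}
```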
59 changes: 59 additions & 0 deletions stage/map.go
@@ -108,6 +108,11 @@ func ParseStage(stage *Stage, stages Map) (*Stage, error) {
}
}
stages[stage.Id] = stage
err := processStreams(stage, stages)
if err != nil {
return nil, fmt.Errorf("failed to process streams for stage %s: %w", stage.Id, err)
}

for _, nextStagePath := range stage.NextStagePaths {
if nextStage, err := ParseStageFromFile(nextStagePath, stages); err != nil {
return nil, err
@@ -150,3 +155,57 @@ func checkStageLinks(stage *Stage) error {
}
return nil
}

func processStreams(stage *Stage, stages Map) error {
if len(stage.Streams) == 0 {
stage.seed = stage.States.RandSeed
return nil
}

for _, spec := range stage.Streams {
if spec.StreamCount <= 0 {
return fmt.Errorf("stream_count must be positive, got %d for stream %s", spec.StreamCount, spec.StreamPath)
}

if len(spec.Seeds) > 0 {
> Collaborator: Why not use your Validate() method? The code seems duplicated.

if len(spec.Seeds) != 1 && len(spec.Seeds) != spec.StreamCount {
return fmt.Errorf("seeds array length (%d) must be either 1 or equal to stream_count (%d) for stream %s",
len(spec.Seeds), spec.StreamCount, spec.StreamPath)
}
stage.States.RandSeed = 0 // Disable random seed generation when custom seeds are provided
}

streamPath, err := spec.GetValidatedPath(stage.BaseDir)
if err != nil {
return err
}
for i := 0; i < spec.StreamCount; i++ {
streamStage, err := ReadStageFromFile(streamPath)
if err != nil {
return fmt.Errorf("failed to read stream file %s: %w", streamPath, err)
}

// Set unique ID for this stream instance
baseId := fileNameWithoutPathAndExt(streamPath)
streamStage.Id = fmt.Sprintf("%s_stream_%d", baseId, i+1)

// Set custom seed if configured
if seed, hasCustomSeed := spec.GetSeedForInstance(i); hasCustomSeed {
streamStage.seed = seed
log.Info().Str("stream_stage", streamStage.Id).Int64("custom_seed", seed).Int("instance", i+1).Msg("stream assigned custom seed")
} else {
// No seed configured, use stage's RandSeed + instance offset
streamStage.seed = stage.States.RandSeed + int64(i)
log.Info().Str("stream_stage", streamStage.Id).Int64("generated_seed", streamStage.seed).Int64("base_seed", stage.States.RandSeed).Int("instance", i+1).Msg("stream assigned generated seed")
}

stages[streamStage.Id] = streamStage
stage.NextStages = append(stage.NextStages, streamStage)
streamStage.wgPrerequisites.Add(1)
}
}

stage.Streams = nil

return nil
}
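GetSeedForInstance is called above but its body is not part of this diff. Given the validation on spec.Seeds (length must be 1 or exactly stream_count), a plausible sketch of its contract, with the Seeds field type assumed, is:

```go
// Hypothetical sketch; the real implementation is not shown in this diff.
// Rules implied by the validation in processStreams:
//   len(Seeds) == 0           -> no custom seed for any instance
//   len(Seeds) == 1           -> every instance shares Seeds[0]
//   len(Seeds) == StreamCount -> instance i uses Seeds[i]
func (s *Streams) GetSeedForInstance(i int) (int64, bool) {
	switch {
	case len(s.Seeds) == 0:
		return 0, false
	case len(s.Seeds) == 1:
		return s.Seeds[0], true
	default:
		return s.Seeds[i], true
	}
}
```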
11 changes: 7 additions & 4 deletions stage/mysql_run_recorder.go
@@ -4,9 +4,10 @@ import (
"context"
"database/sql"
_ "embed"
_ "github.com/go-sql-driver/mysql"
"pbench/log"
"pbench/utils"

_ "github.com/go-sql-driver/mysql"
)

var (
@@ -65,7 +66,7 @@ VALUES (?, ?, ?, 0, 0, 0, ?)`

func (m *MySQLRunRecorder) RecordQuery(_ context.Context, s *Stage, result *QueryResult) {
recordNewQuery := `INSERT INTO pbench_queries (run_id, stage_id, query_file, query_index, query_id, sequence_no,
- cold_run, succeeded, start_time, end_time, row_count, expected_row_count, duration_ms, info_url) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`
+ cold_run, succeeded, start_time, end_time, row_count, expected_row_count, duration_ms, info_url, seed) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`
> Collaborator: Please update pbench_queries_ddl.sql with the new column.

> Collaborator: And why do we need a seed in this table?

> Member Author: There are two tables, pbench_runs and pbench_queries. pbench_runs already has a seed column, but with this new ability to seed each stream individually there had to be a way to report multiple seeds for a single run. The simplest way I found was to add 'seed' as a column to pbench_queries and group by stage_id to present it in Grafana.

> Member Author: I was looking for other ways to report the seed per stream; please let me know if you have any suggestions.
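For illustration only (neither of these is in the diff): the reviewer's request implies a one-column DDL change, and the grouping the author describes could be served by a query along these lines, with table and column names assumed from the discussion:

```go
// Hypothetical; names are taken from the discussion above.
// DDL change requested by the reviewer:
//   ALTER TABLE pbench_queries ADD COLUMN seed BIGINT;
//
// Per-stream seed report for one run, grouped by stage_id as described:
const seedsPerStream = `
SELECT stage_id, MAX(seed) AS seed, COUNT(*) AS queries
FROM pbench_queries
WHERE run_id = ?
GROUP BY stage_id`
```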

var queryFile string
if result.Query.File != nil {
queryFile = *result.Query.File
@@ -83,11 +84,13 @@ cold_run, succeeded, start_time, end_time, row_count, expected_row_count, durati
result.RowCount, sql.NullInt32{
Int32: int32(result.Query.ExpectedRowCount),
Valid: result.Query.ExpectedRowCount >= 0,
- }, result.Duration.Milliseconds(), result.InfoUrl)
+ }, result.Duration.Milliseconds(), result.InfoUrl, result.Seed)
log.Info().Str("stage_id", result.StageId).Stringer("start_time", result.StartTime).Stringer("end_time", result.EndTime).
Str("info_url", result.InfoUrl).Int64("seed", result.Seed).Msg("recorded query result to MySQL")
if err != nil {
log.Error().EmbedObject(result).Err(err).Msg("failed to send query summary to MySQL")
}
- updateRunInfo := `UPDATE pbench_runs SET start_time = ?, queries_ran = queries_ran + 1, failed = ?, mismatch = ? WHERE run_id = ?`
+ updateRunInfo := `UPDATE pbench_runs SET start_time = ?, queries_ran = queries_ran + 1, failed = ? , mismatch = ? WHERE run_id = ?`
res, err := m.db.Exec(updateRunInfo, s.States.RunStartTime, m.failed, m.mismatch, m.runId)
if err != nil {
log.Error().Err(err).Str("run_name", s.States.RunName).Int64("run_id", m.runId).
4 changes: 3 additions & 1 deletion stage/result.go
@@ -1,13 +1,15 @@
package stage

import (
"github.com/rs/zerolog"
"pbench/log"
"time"

"github.com/rs/zerolog"
)

type QueryResult struct {
StageId string
Seed int64
Query *Query
QueryId string
InfoUrl string
102 changes: 83 additions & 19 deletions stage/stage.go
@@ -61,6 +61,10 @@ type Stage struct {
// Use RandomlyExecuteUntil to specify a duration like "1h" or an integer as the number of queries
// that should be executed before exiting.
RandomlyExecuteUntil *string `json:"randomly_execute_until,omitempty"`
// If NoRandomDuplicates is set to true, queries will not be repeated during random execution
// until all queries have been executed once. After that, the selection pool resets if more
// executions are needed.
NoRandomDuplicates *bool `json:"no_random_duplicates,omitempty"`
// If not set, the default is 1. The default value is set when the stage is run.
ColdRuns *int `json:"cold_runs,omitempty" validate:"omitempty,gte=0"`
// If not set, the default is 0.
@@ -87,6 +91,9 @@ type Stage struct {
// knob was not set to true.
SaveJson *bool `json:"save_json,omitempty"`
NextStagePaths []string `json:"next,omitempty"`
// Streams allows specifying stream stages to launch dynamically with custom counts and seeds.
// Format: [{"stream_file_path": "path/to/stream.json", "stream_count": 5, "seeds": [123, 456]}]
Streams []Streams `json:"streams,omitempty"`

// BaseDir is set to the directory path of this stage's location. It is used to locate the descendant stages when
// their locations are specified using relative paths. It is not possible to set this in a stage definition json file.
@@ -101,6 +108,11 @@ type Stage struct {
// Client is by default passed down to descendant stages.
Client *presto.Client `json:"-"`

// seed is the custom seed for this stream/stage instance, used for seeding and identification;
// zero means default seeding. Descendant stages will **NOT** inherit this value from their
// parents, so it is declared as a value, not a pointer.
seed int64 `json:"-"`

// Convenient access to the expected row count array under the current schema.
expectedRowCountInCurrentSchema []int
// Convenient access to the catalog, schema, and timezone
@@ -150,6 +162,7 @@ func (s *Stage) Run(ctx context.Context) int {

go func() {
s.States.wgExitMainStage.Wait()
close(s.States.resultChan)
// wgExitMainStage goes down to 0 after all the goroutines finish. Then we exit the driver by
// closing the timeToExit channel, which will trigger the graceful shutdown process -
// (flushing the log file, writing the final time log summary, etc.).
@@ -174,20 +187,30 @@ func (s *Stage) Run(ctx context.Context) int {

for {
select {
- case result := <-s.States.resultChan:
+ case result, ok := <-s.States.resultChan:
+ if !ok {
+ // resultChan closed: all results received, finalize and exit
+ s.States.RunFinishTime = time.Now()
+ for _, recorder := range s.States.runRecorders {
+ recorder.RecordRun(utils.GetCtxWithTimeout(time.Second*5), s, results)
+ }
+ return int(s.States.exitCode.Load())
+ }
results = append(results, result)
for _, recorder := range s.States.runRecorders {
recorder.RecordQuery(utils.GetCtxWithTimeout(time.Second*5), s, result)
}
- case sig := <-timeToExit:
- if sig != nil {
- // Cancel the context and wait for the goroutines to exit.
- s.States.AbortAll(fmt.Errorf(sig.String()))
+ case sig, ok := <-timeToExit:
+ if !ok {
+ // timeToExit channel closed, no more signals — continue to receive results
+ continue
}
- s.States.RunFinishTime = time.Now()
- for _, recorder := range s.States.runRecorders {
- recorder.RecordRun(utils.GetCtxWithTimeout(time.Second*5), s, results)
+ if sig != nil {
+ // Received shutdown signal; cancel ongoing queries
+ log.Info().Msgf("Shutdown signal received: %v. Aborting queries...", sig)
+ s.States.AbortAll(fmt.Errorf("%s", sig.String()))
+ // Keep receiving results until resultChan is closed
}
- return int(s.States.exitCode.Load())
}
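The reworked loop is the standard drain-then-finalize pattern in Go: producers signal completion by closing resultChan (done in the wgExitMainStage goroutine above), and the consumer finalizes once the ok flag goes false, so no late results are dropped during shutdown. A self-contained sketch of the same pattern with simplified names:

```go
package main

import (
	"fmt"
	"sync"
)

func main() {
	var wg sync.WaitGroup
	results := make(chan int)

	for i := 0; i < 3; i++ {
		wg.Add(1)
		go func(n int) {
			defer wg.Done()
			results <- n * n
		}(i)
	}

	// Close the channel only after every producer is done, so the
	// consumer can use the ok flag to detect completion.
	go func() {
		wg.Wait()
		close(results)
	}()

	for {
		r, ok := <-results
		if !ok {
			fmt.Println("all results received, finalizing")
			return
		}
		fmt.Println("result:", r)
	}
}
```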
@@ -237,8 +260,11 @@ func (s *Stage) run(ctx context.Context) (returnErr error) {
if preStageErr != nil {
return fmt.Errorf("pre-stage script execution failed: %w", preStageErr)
}
- if len(s.Queries)+len(s.QueryFiles) > 0 {
+ if len(s.Queries)+len(s.QueryFiles)+len(s.Streams) > 0 {
if *s.RandomExecution {
if s.RandomlyExecuteUntil == nil {
return fmt.Errorf("randomly_execute_until must be set for random execution in stage %s", s.Id)
}
returnErr = s.runRandomly(ctx)
} else {
returnErr = s.runSequentially(ctx)
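Per the struct comment in stage.go below, randomly_execute_until accepts either a duration string like "1h" or a bare integer count. The PR's actual parsing lives elsewhere in stage.go and is not shown in this diff; a hypothetical helper illustrating how such a value can be disambiguated:

```go
// Hypothetical helper, not from the PR.
// parseExecuteUntil accepts a bare integer (number of queries) or a
// time.ParseDuration string such as "1h" (wall-clock budget).
func parseExecuteUntil(v string) (d time.Duration, count int, err error) {
	if n, convErr := strconv.Atoi(v); convErr == nil {
		return 0, n, nil
	}
	if d, err = time.ParseDuration(v); err != nil {
		return 0, 0, fmt.Errorf("randomly_execute_until must be a duration or an integer: %w", err)
	}
	return d, 0, nil
}
```

(Imports assumed: fmt, strconv, time.)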
@@ -343,21 +369,57 @@ func (s *Stage) runRandomly(ctx context.Context) error {
return nil
}
}
- r := rand.New(rand.NewSource(s.States.RandSeed))

+ r := rand.New(rand.NewSource(s.seed))
+ log.Info().Str("stream_id", s.Id).Int64("custom_seed", s.seed).Msg("initialized with seed")
s.States.RandSeedUsed = true
- log.Info().Int64("seed", s.States.RandSeed).Msg("random source seeded")
- randIndexUpperBound := len(s.Queries) + len(s.QueryFiles)
- for i := 1; continueExecution(i); i++ {
- idx := r.Intn(randIndexUpperBound)
- if i <= s.States.RandSkip {
- if i == s.States.RandSkip {
- log.Info().Msgf("skipped %d random selections", i)

totalQueries := len(s.Queries) + len(s.QueryFiles)

// refreshIndices generates a new set of random indices for selecting queries.
// If NoRandomDuplicates is set to true, it generates a shuffled list of all indices.
// Otherwise, it generates a list of random indices with possible duplicates.
refreshIndices := func() []int {
indices := make([]int, totalQueries)
if s.NoRandomDuplicates != nil && *s.NoRandomDuplicates {
for i := 0; i < totalQueries; i++ {
indices[i] = i
}
r.Shuffle(len(indices), func(i, j int) {
indices[i], indices[j] = indices[j], indices[i]
})
} else {
for i := 0; i < totalQueries; i++ {
indices[i] = r.Intn(totalQueries)
}
}
return indices
}

executionCount := 1
var currentIndices []int
var indexPosition int

for continueExecution(executionCount) {
// Refresh indices when all queries have been used
if currentIndices == nil || indexPosition >= len(currentIndices) {
currentIndices = refreshIndices()
indexPosition = 0
}

idx := currentIndices[indexPosition]
indexPosition++

if executionCount <= s.States.RandSkip {
if executionCount == s.States.RandSkip {
log.Info().Msgf("skipped %d random selections", executionCount)
}
executionCount++
continue
}

if idx < len(s.Queries) {
// Run query embedded in the json file.
pseudoFileName := fmt.Sprintf("rand_%d", i)
pseudoFileName := fmt.Sprintf("rand_%d", executionCount)
if err := s.runQueries(ctx, s.Queries[idx:idx+1], &pseudoFileName, 0); err != nil {
return err
}
@@ -367,11 +429,12 @@
if relPath, relErr := filepath.Rel(s.BaseDir, queryFile); relErr == nil {
fileAlias = relPath
}
fileAlias = fmt.Sprintf("rand_%d_%s", i, fileAlias)
fileAlias = fmt.Sprintf("rand_%d_%s", executionCount, fileAlias)
if err := s.runQueryFile(ctx, queryFile, nil, &fileAlias); err != nil {
return err
}
}
executionCount++
}
log.Info().Msg("random execution concluded.")
return nil
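The no_random_duplicates path above guarantees that within each epoch of totalQueries selections every query runs exactly once before the pool reshuffles; with the flag off, selection is plain sampling with replacement. A standalone sketch of the epoch scheme; note that rand.Perm could replace the manual fill-and-Shuffle used in refreshIndices:

```go
package main

import (
	"fmt"
	"math/rand"
)

func main() {
	r := rand.New(rand.NewSource(42)) // fixed seed for illustration
	n := 5                            // pretend the stage has 5 queries
	indices := r.Perm(n)              // shuffled 0..n-1, same effect as fill + Shuffle
	pos := 0
	for i := 0; i < 12; i++ {
		if pos == len(indices) {
			indices = r.Perm(n) // pool exhausted: reshuffle for the next epoch
			pos = 0
		}
		fmt.Println("run query", indices[pos])
		pos++
	}
}
```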
@@ -476,6 +539,7 @@ func (s *Stage) runQuery(ctx context.Context, query *Query) (result *QueryResult

result = &QueryResult{
StageId: s.Id,
Seed: s.seed,
Query: query,
StartTime: time.Now(),
}
10 changes: 10 additions & 0 deletions stage/stage_utils.go
@@ -68,6 +68,9 @@ func (s *Stage) MergeWith(other *Stage) *Stage {
if other.RandomExecution != nil {
s.RandomExecution = other.RandomExecution
}
if other.NoRandomDuplicates != nil {
s.NoRandomDuplicates = other.NoRandomDuplicates
}
if other.RandomlyExecuteUntil != nil {
s.RandomlyExecuteUntil = other.RandomlyExecuteUntil
}
@@ -92,6 +95,7 @@ func (s *Stage) MergeWith(other *Stage) *Stage {
}
s.NextStagePaths = append(s.NextStagePaths, other.NextStagePaths...)
s.BaseDir = other.BaseDir
s.Streams = append(s.Streams, other.Streams...)

s.PreStageShellScripts = append(s.PreStageShellScripts, other.PreStageShellScripts...)
s.PostQueryShellScripts = append(s.PostQueryShellScripts, other.PostQueryShellScripts...)
@@ -194,6 +198,9 @@ func (s *Stage) setDefaults() {
if s.RandomExecution == nil {
s.RandomExecution = &falseValue
}
if s.NoRandomDuplicates == nil {
s.NoRandomDuplicates = &falseValue
}
if s.AbortOnError == nil {
s.AbortOnError = &falseValue
}
@@ -235,6 +242,9 @@ func (s *Stage) propagateStates() {
if nextStage.RandomExecution == nil {
nextStage.RandomExecution = s.RandomExecution
}
if nextStage.NoRandomDuplicates == nil {
nextStage.NoRandomDuplicates = s.NoRandomDuplicates
}
if nextStage.RandomlyExecuteUntil == nil {
nextStage.RandomlyExecuteUntil = s.RandomlyExecuteUntil
}
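MergeWith, setDefaults, and propagateStates all rely on the knobs being pointer-typed (*bool, *int, *string) so that nil can mean "not configured", which is what lets a child stage inherit only the values its own definition left unset. A minimal sketch of this overlay pattern with illustrative field names:

```go
// Sketch only; field names are hypothetical. nil means "not configured",
// so the overlay never clobbers a value the child set explicitly.
type knobs struct {
	AbortOnError *bool
	ColdRuns     *int
}

func overlay(child, parent *knobs) {
	if child.AbortOnError == nil {
		child.AbortOnError = parent.AbortOnError
	}
	if child.ColdRuns == nil {
		child.ColdRuns = parent.ColdRuns
	}
}
```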