sdk/data/azcosmos/cosmos_container_query_engine.go (77 changes: 47 additions & 30 deletions)
```diff
@@ -182,7 +182,7 @@ func (c *ContainerClient) executeQueryWithEngine(queryEngine queryengine.QueryEngine

 // runEngineRequests concurrently executes per-partition QueryRequests for either query or readMany pipelines.
 // prepareFn returns the query text, parameters, and a drain flag for each request.
-// It serializes ProvideData calls through a single goroutine to preserve ordering guarantees required by the pipeline.
+// Collects all results and calls ProvideData once with a single batch to reduce CGo overhead.
 func runEngineRequests(
 	ctx context.Context,
 	c *ContainerClient,
```
```diff
@@ -192,32 +192,16 @@ func runEngineRequests(
 	requests []queryengine.QueryRequest,
 	concurrency int,
 	prepareFn func(req queryengine.QueryRequest) (query string, params []QueryParameter, drain bool),
-) (totalCharge float32, err error) {
+) (float32, error) {
 	if len(requests) == 0 {
 		return 0, nil
 	}

 	jobs := make(chan queryengine.QueryRequest, len(requests))
-	provideCh := make(chan []queryengine.QueryResult)
+	resultsCh := make(chan queryengine.QueryResult)
 	errCh := make(chan error, 1)
 	done := make(chan struct{})
-	providerDone := make(chan struct{})
 	var wg sync.WaitGroup
-	var chargeMu sync.Mutex
-
-	// Provider goroutine ensures only one ProvideData executes at a time.
-	go func() {
-		defer close(providerDone)
-		for batch := range provideCh {
-			if perr := pipeline.ProvideData(batch); perr != nil {
-				select {
-				case errCh <- perr:
-				default:
-				}
-				return
-			}
-		}
-	}()

 	// Adjust concurrency.
 	workerCount := concurrency
```
```diff
@@ -228,9 +212,25 @@ func runEngineRequests(
 		workerCount = 1
 	}

+	// Per-worker request charge slots (no lock needed)
+	charges := make([]float32, workerCount)
+
+	// Collector goroutine gathers all results
+	var allResults []queryengine.QueryResult
+	var resultsMu sync.Mutex
+	collectorDone := make(chan struct{})
+	go func() {
+		defer close(collectorDone)
+		for result := range resultsCh {
+			resultsMu.Lock()
+			allResults = append(allResults, result)
+			resultsMu.Unlock()
+		}
+	}()
+
 	for w := 0; w < workerCount; w++ {
 		wg.Add(1)
-		go func() {
+		go func(workerIndex int) {
 			defer wg.Done()
 			for {
 				select {
```
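The collector added in this hunk is a standard fan-in: many workers send single results over a channel, one goroutine appends them to a slice, and a separate done channel tells the caller the slice is complete. A minimal runnable sketch of the same shape, with illustrative types that are not from the SDK:

```go
package main

import (
	"fmt"
	"sync"
)

func main() {
	results := make(chan int)
	collectorDone := make(chan struct{})
	var collected []int

	// Fan-in: a single consumer appends, so the append itself needs no
	// lock; closing collectorDone publishes the finished slice.
	go func() {
		defer close(collectorDone)
		for r := range results {
			collected = append(collected, r)
		}
	}()

	var wg sync.WaitGroup
	for w := 0; w < 4; w++ {
		wg.Add(1)
		go func(id int) {
			defer wg.Done()
			results <- id * 10 // each producer sends one result
		}(w)
	}

	// Close the results channel only after every producer has finished.
	go func() { wg.Wait(); close(results) }()

	<-collectorDone
	fmt.Println(collected) // four values, in arrival order
}
```

(The SDK version additionally holds resultsMu around the append; in a single-consumer fan-in like this sketch the channel close alone publishes the slice.)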
```diff
@@ -274,9 +274,7 @@ func runEngineRequests(
 							}
 							return
 						}
-						chargeMu.Lock()
-						totalCharge += qResp.RequestCharge
-						chargeMu.Unlock()
+						charges[workerIndex] += qResp.RequestCharge

 						// Load the data into a buffer to send it to the pipeline
 						buf := new(bytes.Buffer)
```
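Replacing the mutex-guarded totalCharge with per-worker slots works because each goroutine writes only charges[workerIndex], and the later wg.Wait() gives the reader a happens-before edge over all of those writes. A small sketch of the idiom (the values are made up):

```go
package main

import (
	"fmt"
	"sync"
)

func main() {
	const workers = 4
	charges := make([]float32, workers) // one slot per worker, no lock

	var wg sync.WaitGroup
	for w := 0; w < workers; w++ {
		wg.Add(1)
		go func(i int) {
			defer wg.Done()
			for j := 0; j < 100; j++ {
				charges[i] += 0.5 // each worker touches only charges[i]
			}
		}(w)
	}
	wg.Wait() // all writes to charges are visible after this point

	var total float32
	for _, c := range charges {
		total += c
	}
	fmt.Println(total) // 200
}
```

Adjacent float32 slots may share a cache line, so the idiom trades a little potential false sharing for lock-free accumulation; with a handful of workers that is normally a fine trade.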
```diff
@@ -302,12 +300,12 @@ func runEngineRequests(
 						select {
 						case <-done:
 							return
-						case provideCh <- []queryengine.QueryResult{result}:
+						case resultsCh <- result:
 						}
 					}
 				}
 			}
-		}()
+		}(w)
 	}

 	// Feed jobs
```
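The send into resultsCh uses Go's usual cancellable-send shape: racing the send against a done channel inside a select, so a worker can never block forever once the operation has been abandoned. Extracted into a standalone helper (the trySend name is mine, not the SDK's):

```go
package main

import "fmt"

// trySend delivers v on ch unless done is closed first, so a producer
// never wedges on a consumer that has stopped reading.
func trySend[T any](done <-chan struct{}, ch chan<- T, v T) bool {
	select {
	case <-done:
		return false
	case ch <- v:
		return true
	}
}

func main() {
	done := make(chan struct{})
	ch := make(chan string, 1)

	fmt.Println(trySend(done, ch, "first")) // true: buffer slot free

	close(done)
	fmt.Println(trySend(done, ch, "second")) // false: buffer full and done closed
}
```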
```diff
@@ -323,8 +321,17 @@ func runEngineRequests(
 		close(jobs)
 	}()

-	// Close provider after workers finish
-	go func() { wg.Wait(); close(provideCh) }()
+	// Close results channel after workers finish
+	go func() { wg.Wait(); close(resultsCh) }()
+
+	// Helper to sum charges
+	sumCharges := func() float32 {
+		var total float32
+		for _, charge := range charges {
+			total += charge
+		}
+		return total
+	}

 	// Wait for completion / error / cancellation
 	select {
```
```diff
@@ -334,15 +341,25 @@ func runEngineRequests(
 		default:
 			close(done)
 		}
-		return totalCharge, e
+		return sumCharges(), e
 	case <-ctx.Done():
 		select {
 		case <-done:
 		default:
 			close(done)
 		}
-		return totalCharge, ctx.Err()
-	case <-providerDone:
+		return sumCharges(), ctx.Err()
+	case <-collectorDone:
 	}

+	// Sum up all worker charges
+	totalCharge := sumCharges()
+
+	// Provide all collected results in a single batch
+	if len(allResults) > 0 {
+		if err := pipeline.ProvideData(allResults); err != nil {
+			return totalCharge, err
+		}
+	}
+
 	return totalCharge, nil
```
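Taken together, the change moves from one ProvideData call per result to a single call carrying the whole batch, which matters when each call crosses an expensive boundary (here, CGo into the query engine). A toy illustration of the difference, with a hypothetical DataSink standing in for the pipeline:

```go
package main

import "fmt"

// DataSink stands in for the query pipeline; assume each call has a
// fixed crossing cost (for the SDK, a CGo transition).
type DataSink interface {
	ProvideData(batch []string) error
}

type countingSink struct{ calls int }

func (s *countingSink) ProvideData(batch []string) error {
	s.calls++
	fmt.Printf("call %d: %d items\n", s.calls, len(batch))
	return nil
}

func main() {
	sink := &countingSink{}
	results := []string{"a", "b", "c", "d"}

	// Per-item: len(results) boundary crossings.
	for _, r := range results {
		_ = sink.ProvideData([]string{r})
	}

	// Batched: the same payload in a single crossing.
	if len(results) > 0 {
		_ = sink.ProvideData(results)
	}
}
```

The worker and collector machinery above exists to build that one batch safely; under this change the boundary is touched once per invocation of runEngineRequests rather than once per partition result.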