@@ -2,7 +2,6 @@ package forward
 
 import (
 	"context"
-	"github.com/spf13/cobra"
 	"net/http"
 	"os"
 	"os/signal"
@@ -16,6 +15,8 @@ import (
 	"sync/atomic"
 	"syscall"
 	"time"
+
+	"github.com/spf13/cobra"
 )
 
 var (
@@ -35,9 +36,22 @@ var (
 	runningTasks    sync.WaitGroup
 	failedToForward atomic.Uint32
 	forwarded       atomic.Uint32
+	// runningQueriesCacheMap tracks queries currently being forwarded, keyed by
+	// the query id on the source cluster. Each entry records the nextUri of a
+	// forwarded query and the client for the target cluster it runs on.
+	runningQueriesCacheMap = make(map[string][]QueryCacheEntry)
+	// runningQueriesCacheMapMu guards the map and its entries, which are written
+	// by forwardQuery goroutines and read by checkAndCancelQuery goroutines.
+	runningQueriesCacheMapMu sync.RWMutex
 )
 
-const maxRetry = 10
+const (
+	maxRetry                 = 10
+	queryStateErrorCancelled = "USER_CANCELED"
+	queryStateFailed         = "FAILED"
+)
+
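+// QueryCacheEntry records where a forwarded copy of a query is running: the
+// nextUri used to poll or cancel it, and the client of the target cluster.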
+type QueryCacheEntry struct {
+	NextUri string
+	Client  *presto.Client
+}
 
 func waitForNextPoll(ctx context.Context) {
 	timer := time.NewTimer(PollInterval)
@@ -142,6 +156,10 @@ func Run(_ *cobra.Command, _ []string) {
 	// to process get filtered out.
 	newCutoffTime := lastQueryStateCheckCutoffTime
 	for _, state := range states {
+		// If the user cancelled this query on the source cluster, cancel its
+		// forwarded copies on the target clusters as well.
+		if state.QueryState == queryStateFailed && state.ErrorCode.Name == queryStateErrorCancelled {
+			state := state // copy: before Go 1.22 the loop variable is shared across iterations
+			go checkAndCancelQuery(ctx, &state)
+		}
 		if !state.CreateTime.After(lastQueryStateCheckCutoffTime) {
 			// We looked at this query in the previous batch.
 			continue
@@ -171,6 +189,19 @@ func Run(_ *cobra.Command, _ []string) {
 		Msgf("finished forwarding queries")
 }
 
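+// checkAndCancelQuery cancels the forwarded copies of a query that the user
+// cancelled on the source cluster, using the cache entries recorded by
+// forwardQuery.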
+func checkAndCancelQuery(ctx context.Context, queryState *presto.QueryStateInfo) {
+	// Snapshot the entries under the read lock; forwardQuery goroutines may
+	// still be filling them in concurrently.
+	runningQueriesCacheMapMu.RLock()
+	queryCacheEntries := append([]QueryCacheEntry(nil), runningQueriesCacheMap[queryState.QueryId]...)
+	runningQueriesCacheMapMu.RUnlock()
+	for _, q := range queryCacheEntries {
+		if q.NextUri == "" {
+			continue
+		}
+		if _, _, cancelQueryErr := q.Client.CancelQuery(ctx, q.NextUri); cancelQueryErr != nil {
+			log.Error().Err(cancelQueryErr).Str("next_uri", q.NextUri).
+				Msg("failed to cancel query on target cluster")
+		}
+	}
+}
+
 func forwardQuery(ctx context.Context, queryState *presto.QueryStateInfo, clients []*presto.Client) {
 	defer runningTasks.Done()
 	var (
@@ -226,6 +257,7 @@ func forwardQuery(ctx context.Context, queryState *presto.QueryStateInfo, client
 	}
 	successful, failed := atomic.Uint32{}, atomic.Uint32{}
 	forwardedQueries := sync.WaitGroup{}
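+	// One cache slot per target cluster; the goroutine forwarding to clients[i]
+	// fills slot i-1 once the query has started on that target.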
+	cachedQueries := make([]QueryCacheEntry, len(clients)-1)
 	for i := 1; i < len(clients); i++ {
 		forwardedQueries.Add(1)
+		i := i // copy: the goroutine below indexes cachedQueries with i (pre-Go 1.22 loops share i)
 		go func(client *presto.Client) {
@@ -246,6 +278,10 @@ func forwardQuery(ctx context.Context, queryState *presto.QueryStateInfo, client
 				failed.Add(1)
 				return
 			}
+			// Record the running query so checkAndCancelQuery can cancel it on
+			// this target; guard the slot write against a concurrent snapshot.
+			if clientResult.NextUri != nil {
+				runningQueriesCacheMapMu.Lock()
+				cachedQueries[i-1] = QueryCacheEntry{NextUri: *clientResult.NextUri, Client: client}
+				runningQueriesCacheMapMu.Unlock()
+			}
 			rowCount := 0
 			drainErr := clientResult.Drain(ctx, func(qr *presto.QueryResults) error {
 				rowCount += len(qr.Data)
@@ -262,8 +298,14 @@ func forwardQuery(ctx context.Context, queryState *presto.QueryStateInfo, client
 				Str("target_host", client.GetHost()).Int("row_count", rowCount).Msg("query executed successfully")
 		}(clients[i])
 	}
+	// Publish the cache entries before waiting: a cancellation that arrives
+	// while the forwarded queries are still draining must be able to find them.
+	runningQueriesCacheMapMu.Lock()
+	runningQueriesCacheMap[queryState.QueryId] = cachedQueries
+	runningQueriesCacheMapMu.Unlock()
+	log.Info().Str("source_query_id", queryState.QueryId).Msg("added query to cache")
 	forwardedQueries.Wait()
 	log.Info().Str("source_query_id", queryInfo.QueryId).Uint32("successful", successful.Load()).
 		Uint32("failed", failed.Load()).Msg("query forwarding finished")
 	forwarded.Add(1)
+	// All targets have finished; drop the query from the cache.
+	runningQueriesCacheMapMu.Lock()
+	delete(runningQueriesCacheMap, queryState.QueryId)
+	runningQueriesCacheMapMu.Unlock()
+	log.Info().Str("source_query_id", queryState.QueryId).Msg("removed query from cache")
 }