@@ -2,12 +2,17 @@ package forward
22
33import (
44 "context"
5- "fmt"
65 "github.com/spf13/cobra"
6+ "net/http"
7+ "os"
8+ "os/signal"
9+ "path/filepath"
710 "pbench/log"
811 "pbench/presto"
912 "pbench/utils"
1013 "sync"
14+ "sync/atomic"
15+ "syscall"
1116 "time"
1217)
1318
@@ -17,49 +22,141 @@ var (
1722 RunName string
1823 PollInterval time.Duration
1924
20- runningTasks sync.WaitGroup
25+ runningTasks sync.WaitGroup
26+ failedToForward atomic.Uint32
27+ forwarded atomic.Uint32
2128)
2229
23- type QueryHistory struct {
24- QueryId string `presto:"query_id"`
25- Query string `presto:"query"`
26- Created * time.Time `presto:"created"`
27- }
28-
2930func Run (_ * cobra.Command , _ []string ) {
30- //OutputPath = filepath.Join(OutputPath, RunName)
31- //utils.PrepareOutputDirectory(OutputPath)
32- //
33- //// also start to write logs to the output directory from this point on.
34- //logPath := filepath.Join(OutputPath, "forward.log")
35- //flushLog := utils.InitLogFile(logPath)
36- //defer flushLog()
31+ OutputPath = filepath .Join (OutputPath , RunName )
32+ utils .PrepareOutputDirectory (OutputPath )
3733
38- prestoClusters := PrestoFlagsArray .Assemble ()
34+ // also start to write logs to the output directory from this point on.
35+ logPath := filepath .Join (OutputPath , "forward.log" )
36+ flushLog := utils .InitLogFile (logPath )
37+ defer flushLog ()
38+
39+ ctx , cancel := context .WithCancel (context .Background ())
40+ timeToExit := make (chan os.Signal , 1 )
41+ signal .Notify (timeToExit , syscall .SIGINT , syscall .SIGTERM , syscall .SIGQUIT )
42+ // Handle SIGINT, SIGTERM, and SIGQUIT. When ctx is canceled, in-progress MySQL transactions and InfluxDB operations will roll back.
43+ go func () {
44+ sig := <- timeToExit
45+ if sig != nil {
46+ log .Info ().Msg ("abort forwarding" )
47+ cancel ()
48+ }
49+ }()
50+
51+ prestoClusters := PrestoFlagsArray .Pivot ()
3952 // The design here is to forward the traffic from cluster 0 to the rest.
4053 sourceClusterSize := 0
4154 clients := make ([]* presto.Client , 0 , len (prestoClusters ))
4255 for i , cluster := range prestoClusters {
4356 clients = append (clients , cluster .NewPrestoClient ())
44- if stats , _ , err := clients [i ].GetClusterInfo (context .Background ()); err != nil {
45- log .Fatal ().Err (err ).Msgf ("cannot connect to cluster at position %d" , i )
57+ // Check if we can connect to the cluster.
58+ if stats , _ , err := clients [i ].GetClusterInfo (ctx ); err != nil {
59+ log .Fatal ().Err (err ).Msgf ("cannot connect to cluster at position %d: %s" , i , cluster .ServerUrl )
4660 } else if i == 0 {
4761 sourceClusterSize = stats .ActiveWorkers
4862 } else if stats .ActiveWorkers != sourceClusterSize {
49- log .Warn ().Msgf ("source cluster size does not match target cluster %d size (%d != %d)" , i , stats .ActiveWorkers , sourceClusterSize )
63+ log .Warn ().Msgf ("the source cluster and target cluster %d do not match in size (%d != %d)" , i , sourceClusterSize , stats .ActiveWorkers )
5064 }
5165 }
5266
5367 sourceClient := clients [0 ]
5468 trueValue := true
55- states , _ , err := sourceClient .GetQueryState (context .Background (), & presto.GetQueryStatsOptions {
56- IncludeAllQueries : & trueValue ,
57- IncludeAllQueryProgressStats : nil ,
58- ExcludeResourceGroupPathInfo : nil ,
59- QueryTextSizeLimit : nil ,
60- })
61- if err != nil {
62- log .Fatal ().Err (err ).Msgf ("cannot get query states" )
69+ // lastQueryStateCheckCutoffTime is the query create time of the most recent query in the previous batch.
70+ // We only look at queries created later than this timestamp in the following batch.
71+ lastQueryStateCheckCutoffTime := time.Time {}
72+ firstBatch := true
73+ // Keep running until the source cluster becomes unavailable or the user interrupts or quits using Ctrl + C or Ctrl + D.
74+ for ctx .Err () == nil {
75+ states , _ , err := sourceClient .GetQueryState (ctx , & presto.GetQueryStatsOptions {IncludeAllQueries : & trueValue })
76+ if err != nil {
77+ log .Error ().Err (err ).Msgf ("failed to get query states" )
78+ break
79+ }
80+ newCutoffTime := time.Time {}
81+ for _ , state := range states {
82+ if ! state .CreateTime .After (lastQueryStateCheckCutoffTime ) {
83+ // We looked at this query in the previous batch.
84+ continue
85+ }
86+ if newCutoffTime .Before (state .CreateTime ) {
87+ newCutoffTime = state .CreateTime
88+ }
89+ if ! firstBatch {
90+ runningTasks .Add (1 )
91+ go forwardQuery (ctx , & state , clients )
92+ }
93+ }
94+ firstBatch = false
95+ if newCutoffTime .After (lastQueryStateCheckCutoffTime ) {
96+ lastQueryStateCheckCutoffTime = newCutoffTime
97+ }
98+ timer := time .NewTimer (PollInterval )
99+ select {
100+ case <- ctx .Done ():
101+ case <- timer .C :
102+ }
103+ }
104+ runningTasks .Wait ()
105+ // This causes the signal handler to exit.
106+ close (timeToExit )
107+ log .Info ().Uint32 ("forwarded" , forwarded .Load ()).Uint32 ("failed_to_forward" , failedToForward .Load ()).
108+ Msgf ("finished forwarding queries" )
109+ }
110+
111+ func forwardQuery (ctx context.Context , queryState * presto.QueryStateInfo , clients []* presto.Client ) {
112+ defer runningTasks .Done ()
113+ queryInfo , _ , queryInfoErr := clients [0 ].GetQueryInfo (ctx , queryState .QueryId , false , nil )
114+ if queryInfoErr != nil {
115+ log .Error ().Str ("query_id" , queryState .QueryId ).Err (queryInfoErr ).Msg ("failed to get query info for forwarding" )
116+ failedToForward .Add (1 )
117+ return
118+ }
119+ SessionPropertyHeader := clients [0 ].GenerateSessionParamsHeaderValue (queryInfo .Session .CollectSessionProperties ())
120+ successful , failed := atomic.Uint32 {}, atomic.Uint32 {}
121+ forwardedQueries := sync.WaitGroup {}
122+ for i := 1 ; i < len (clients ); i ++ {
123+ forwardedQueries .Add (1 )
124+ go func (client * presto.Client ) {
125+ defer forwardedQueries .Done ()
126+ clientResult , _ , queryErr := client .Query (ctx , queryInfo .Query , func (req * http.Request ) {
127+ if queryInfo .Session .Catalog != nil {
128+ req .Header .Set (presto .CatalogHeader , * queryInfo .Session .Catalog )
129+ }
130+ if queryInfo .Session .Schema != nil {
131+ req .Header .Set (presto .SchemaHeader , * queryInfo .Session .Schema )
132+ }
133+ req .Header .Set (presto .SessionHeader , SessionPropertyHeader )
134+ req .Header .Set (presto .SourceHeader , queryInfo .QueryId )
135+ })
136+ if queryErr != nil {
137+ log .Error ().Str ("source_query_id" , queryInfo .QueryId ).
138+ Str ("target_host" , client .GetHost ()).Err (queryErr ).Msg ("failed to execute query" )
139+ failed .Add (1 )
140+ return
141+ }
142+ rowCount := 0
143+ drainErr := clientResult .Drain (ctx , func (qr * presto.QueryResults ) error {
144+ rowCount += len (qr .Data )
145+ return nil
146+ })
147+ if drainErr != nil {
148+ log .Error ().Str ("source_query_id" , queryInfo .QueryId ).
149+ Str ("target_host" , client .GetHost ()).Err (drainErr ).Msg ("failed to fetch query result" )
150+ failed .Add (1 )
151+ return
152+ }
153+ successful .Add (1 )
154+ log .Info ().Str ("source_query_id" , queryInfo .QueryId ).
155+ Str ("target_host" , client .GetHost ()).Int ("row_count" , rowCount ).Msg ("query executed successfully" )
156+ }(clients [i ])
63157 }
64- fmt .Printf ("%#v" , states )
158+ forwardedQueries .Wait ()
159+ log .Info ().Str ("source_query_id" , queryInfo .QueryId ).Uint32 ("successful" , successful .Load ()).
160+ Uint32 ("failed" , failed .Load ()).Msg ("query forwarding finished" )
161+ forwarded .Add (1 )
65162}
0 commit comments