@@ -265,3 +265,58 @@ func TestQuit(t *testing.T) {
265265 require .Equal (t , 1 , strings .Count (data , "# Cmd_type: Quit" ))
266266 require .Equal (t , uint64 (3 ), cpt .capturedCmds )
267267}
268+
269+ func TestFilterCmds (t * testing.T ) {
270+ tests := []struct {
271+ packet []byte
272+ want string
273+ notWant string
274+ }{
275+ {
276+ packet : pnet .MakeChangeUser (& pnet.ChangeUserReq {
277+ User : "root" ,
278+ DB : "test" ,
279+ }, 0 ),
280+ want : pnet .ComResetConnection .String (),
281+ notWant : pnet .ComChangeUser .String (),
282+ },
283+ {
284+ packet : append ([]byte {pnet .ComQuery .Byte ()}, []byte ("CREATE USER u1 IDENTIFIED BY '123456'" )... ),
285+ notWant : "123456" ,
286+ },
287+ {
288+ packet : append ([]byte {pnet .ComQuery .Byte ()}, []byte ("select 1" )... ),
289+ want : "select 1" ,
290+ },
291+ }
292+
293+ cfg := CaptureConfig {
294+ Output : t .TempDir (),
295+ Duration : 10 * time .Second ,
296+ }
297+ for i , test := range tests {
298+ cpt := NewCapture (zap .NewNop ())
299+ writer := newMockWriter (store.WriterCfg {})
300+ cfg .cmdLogger = writer
301+ require .NoError (t , cpt .Start (cfg ))
302+ cpt .Capture (test .packet , time .Now (), 100 , func () (string , error ) {
303+ return "init session 100" , nil
304+ })
305+ cpt .Stop (nil )
306+
307+ data := string (writer .getData ())
308+ if len (test .want ) > 0 {
309+ require .Equal (t , 1 , strings .Count (data , test .want ), "case %d" , i )
310+ require .Equal (t , uint64 (2 ), cpt .capturedCmds , "case %d" , i )
311+ require .Equal (t , uint64 (0 ), cpt .filteredCmds , "case %d" , i )
312+ } else {
313+ require .Equal (t , uint64 (1 ), cpt .capturedCmds , "case %d" , i )
314+ require .Equal (t , uint64 (1 ), cpt .filteredCmds , "case %d" , i )
315+ }
316+ if len (test .notWant ) > 0 {
317+ require .Equal (t , 0 , strings .Count (data , test .notWant ), "case %d" , i )
318+ }
319+
320+ cpt .Close ()
321+ }
322+ }
0 commit comments