@@ -19,6 +19,7 @@ import (
1919 "github.com/cockroachdb/cockroach/pkg/security/username"
2020 "github.com/cockroachdb/cockroach/pkg/server"
2121 "github.com/cockroachdb/cockroach/pkg/server/apiconstants"
22+ "github.com/cockroachdb/cockroach/pkg/server/authserver"
2223 "github.com/cockroachdb/cockroach/pkg/server/serverpb"
2324 "github.com/cockroachdb/cockroach/pkg/server/srvtestutils"
2425 "github.com/cockroachdb/cockroach/pkg/settings/cluster"
@@ -1540,6 +1541,93 @@ func TestCombinedStatementUsesCorrectSourceTable(t *testing.T) {
15401541 }
15411542}
15421543
1544+ func TestDrainSqlStats (t * testing.T ) {
1545+ defer leaktest .AfterTest (t )()
1546+ defer log .Scope (t ).Close (t )
1547+ appName := "drain_stats_app"
1548+ testCluster := serverutils .StartCluster (t , 3 , base.TestClusterArgs {})
1549+ ctx := context .Background ()
1550+ defer testCluster .Stopper ().Stop (ctx )
1551+
1552+ conn1 := sqlutils .MakeSQLRunner (testCluster .ServerConn (0 ))
1553+ conn2 := sqlutils .MakeSQLRunner (testCluster .ServerConn (1 ))
1554+ conn3 := sqlutils .MakeSQLRunner (testCluster .ServerConn (2 ))
1555+
1556+ for _ , conn := range []* sqlutils.SQLRunner {conn1 , conn2 , conn3 } {
1557+ conn .Exec (t , fmt .Sprintf ("SET application_name = '%s'" , appName ))
1558+ conn .Exec (t , "SELECT 1" )
1559+ }
1560+
1561+ statusServer := testCluster .Server (0 ).GetStatusClient (t )
1562+ resp , err := statusServer .DrainSqlStats (ctx , & serverpb.DrainSqlStatsRequest {})
1563+ require .NoError (t , err )
1564+ checkFingerprintCount (t , resp )
1565+ stmts , txns := filterStatementStatsByAppName (resp .Statements , resp .Transactions , appName )
1566+ require .Len (t , stmts , 1 )
1567+ require .Equal (t , int64 (3 ), stmts [0 ].Stats .Count )
1568+ require .Equal (t , "SELECT _" , stmts [0 ].Key .Query )
1569+ require .Len (t , txns , 1 )
1570+ require .Equal (t , int64 (3 ), txns [0 ].Stats .Count )
1571+ require .Len (t , txns [0 ].StatementFingerprintIDs , 1 )
1572+ require .Equal (t , stmts [0 ].ID , txns [0 ].StatementFingerprintIDs [0 ])
1573+
1574+ // Check that the stats are cleared.
1575+ resp , err = statusServer .DrainSqlStats (ctx , & serverpb.DrainSqlStatsRequest {})
1576+ require .NoError (t , err )
1577+ stmts , txns = filterStatementStatsByAppName (resp .Statements , resp .Transactions , appName )
1578+ require .Empty (t , stmts )
1579+ require .Empty (t , txns )
1580+ }
1581+
1582+ func TestDrainSqlStats_partialOutage (t * testing.T ) {
1583+ defer leaktest .AfterTest (t )()
1584+ defer log .Scope (t ).Close (t )
1585+ appName := "drain_stats_app"
1586+ testCluster := serverutils .StartCluster (t , 3 , base.TestClusterArgs {})
1587+ ctx := context .Background ()
1588+ defer testCluster .Stopper ().Stop (ctx )
1589+
1590+ conn1 := sqlutils .MakeSQLRunner (testCluster .ServerConn (0 ))
1591+ conn2 := sqlutils .MakeSQLRunner (testCluster .ServerConn (1 ))
1592+ conn3 := sqlutils .MakeSQLRunner (testCluster .ServerConn (2 ))
1593+
1594+ for _ , conn := range []* sqlutils.SQLRunner {conn1 , conn2 , conn3 } {
1595+ conn .Exec (t , fmt .Sprintf ("SET application_name = '%s'" , appName ))
1596+ conn .Exec (t , "SELECT 1" )
1597+ }
1598+
1599+ // Stop server 2 to simulate a partial outage
1600+ testCluster .StopServer (2 )
1601+ statusServer := testCluster .Server (0 ).GetStatusClient (t )
1602+ resp , err := statusServer .DrainSqlStats (ctx , & serverpb.DrainSqlStatsRequest {})
1603+ require .NoError (t , err )
1604+ checkFingerprintCount (t , resp )
1605+ stmts , txns := filterStatementStatsByAppName (resp .Statements , resp .Transactions , appName )
1606+ require .Len (t , stmts , 1 )
1607+ require .Equal (t , int64 (2 ), stmts [0 ].Stats .Count )
1608+ require .Equal (t , "SELECT _" , stmts [0 ].Key .Query )
1609+ require .Len (t , txns , 1 )
1610+ require .Equal (t , int64 (2 ), txns [0 ].Stats .Count )
1611+ }
1612+
1613+ func TestDrainSqlStatsPermissionDenied (t * testing.T ) {
1614+ defer leaktest .AfterTest (t )()
1615+ defer log .Scope (t ).Close (t )
1616+ ts := serverutils .StartServerOnly (t , base.TestServerArgs {})
1617+ ctx := context .Background ()
1618+ nonRootUser := apiconstants .TestingUserNameNoAdmin ()
1619+ sqlutils .MakeSQLRunner (ts .SQLConn (t )).Exec (t , fmt .Sprintf ("CREATE USER IF NOT EXISTS %s" , nonRootUser ))
1620+ ctx = authserver .ContextWithHTTPAuthInfo (ctx , nonRootUser .Normalized (), 1 )
1621+ ctx = authserver .ForwardHTTPAuthInfoToRPCCalls (ctx , nil )
1622+
1623+ statusClient := ts .GetStatusClient (t )
1624+ defer ts .Stopper ().Stop (ctx )
1625+ _ , err := statusClient .DrainSqlStats (ctx , & serverpb.DrainSqlStatsRequest {})
1626+
1627+ require .Error (t , err )
1628+ require .Contains (t , err .Error (), "user does not have admin role" )
1629+ }
1630+
15431631func createStmtFetchMode (
15441632 sort serverpb.StatsSortOptions ,
15451633) * serverpb.CombinedStatementsStatsRequest_FetchMode {
@@ -1860,3 +1948,44 @@ VALUES (
18601948 }
18611949 db .Exec (t , query , args ... )
18621950}
1951+
1952+ func filterStatementStatsByAppName (
1953+ statements []* appstatspb.CollectedStatementStatistics ,
1954+ transactions []* appstatspb.CollectedTransactionStatistics ,
1955+ appName string ,
1956+ ) ([]* appstatspb.CollectedStatementStatistics , []* appstatspb.CollectedTransactionStatistics ) {
1957+ var filteredStatements []* appstatspb.CollectedStatementStatistics
1958+ var filteredTransactions []* appstatspb.CollectedTransactionStatistics
1959+
1960+ for _ , stmt := range statements {
1961+ if stmt .Key .App == appName {
1962+ filteredStatements = append (filteredStatements , stmt )
1963+ }
1964+ }
1965+
1966+ for _ , txn := range transactions {
1967+ if txn .App == appName {
1968+ filteredTransactions = append (filteredTransactions , txn )
1969+ }
1970+ }
1971+
1972+ return filteredStatements , filteredTransactions
1973+ }
1974+
1975+ func checkFingerprintCount (t * testing.T , resp * serverpb.DrainStatsResponse ) {
1976+ stmtFingerprints := make (map [appstatspb.StmtFingerprintID ]struct {})
1977+ txnFingerprints := make (map [appstatspb.TransactionFingerprintID ]struct {})
1978+ for _ , stmt := range resp .Statements {
1979+ if _ , ok := stmtFingerprints [stmt .ID ]; ! ok {
1980+ stmtFingerprints [stmt .ID ] = struct {}{}
1981+ }
1982+ }
1983+
1984+ for _ , txn := range resp .Transactions {
1985+ if _ , ok := txnFingerprints [txn .TransactionFingerprintID ]; ! ok {
1986+ txnFingerprints [txn .TransactionFingerprintID ] = struct {}{}
1987+ }
1988+ }
1989+ actualFpCount := len (stmtFingerprints ) + len (txnFingerprints )
1990+ require .Equal (t , resp .FingerprintCount , int64 (actualFpCount ))
1991+ }
0 commit comments