Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 0 additions & 14 deletions internal/component/database_observability/lexer.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,20 +7,6 @@ import (
"github.com/DataDog/go-sqllexer"
)

// ExtractTableNames extracts the table names from a SQL query
func ExtractTableNames(sql string) ([]string, error) {
normalizer := sqllexer.NewNormalizer(
sqllexer.WithCollectTables(true),
)
_, metadata, err := normalizer.Normalize(sql, sqllexer.WithDBMS(sqllexer.DBMSPostgres))
if err != nil {
return nil, fmt.Errorf("failed to normalize SQL: %w", err)
}

// Return all table names, including those that end with "..." for truncated queries, as we can't know if the table name was truncated or not
return metadata.Tables, nil
}

// RedactSql obfuscates a SQL query by replacing literals with ? placeholders
func RedactSql(sql string) string {
obfuscatedSql := sqllexer.NewObfuscator().Obfuscate(sql)
Expand Down
164 changes: 18 additions & 146 deletions internal/component/database_observability/lexer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,26 +49,26 @@ func TestPgSqlParser_Redact(t *testing.T) {
sql: `WITH active_users AS (
SELECT * FROM users WHERE last_login > '2024-01-01'
), recent_orders AS (
SELECT o.* FROM orders o
JOIN active_users u ON u.id = o.user_id
SELECT o.* FROM orders o
JOIN active_users u ON u.id = o.user_id
WHERE o.created_at > '2024-03-01'
)
SELECT au.name, COUNT(ro.id) as order_count
FROM active_users au
LEFT JOIN recent_orders ro ON ro.user_id = au.id
GROUP BY au.name
SELECT au.name, COUNT(ro.id) as order_count
FROM active_users au
LEFT JOIN recent_orders ro ON ro.user_id = au.id
GROUP BY au.name
HAVING COUNT(ro.id) > 5`,
want: `WITH active_users AS (
SELECT * FROM users WHERE last_login > ?
), recent_orders AS (
SELECT o.* FROM orders o
JOIN active_users u ON u.id = o.user_id
SELECT o.* FROM orders o
JOIN active_users u ON u.id = o.user_id
WHERE o.created_at > ?
)
SELECT au.name, COUNT(ro.id) as order_count
FROM active_users au
LEFT JOIN recent_orders ro ON ro.user_id = au.id
GROUP BY au.name
SELECT au.name, COUNT(ro.id) as order_count
FROM active_users au
LEFT JOIN recent_orders ro ON ro.user_id = au.id
GROUP BY au.name
HAVING COUNT(ro.id) > ?`,
},
{
Expand Down Expand Up @@ -106,13 +106,13 @@ func TestPgSqlParser_Redact(t *testing.T) {
{
name: "WITH statement with UPDATE",
sql: `WITH inactive_users AS (
SELECT id FROM users
SELECT id FROM users
WHERE last_login < '2023-01-01' AND status = 'active'
)
UPDATE users SET status = 'inactive', updated_at = '2024-03-20'
WHERE id IN (SELECT id FROM inactive_users)`,
want: `WITH inactive_users AS (
SELECT id FROM users
SELECT id FROM users
WHERE last_login < ? AND status = ?
)
UPDATE users SET status = ?, updated_at = ?
Expand All @@ -121,16 +121,16 @@ func TestPgSqlParser_Redact(t *testing.T) {
{
name: "WITH statement with DELETE",
sql: `WITH old_orders AS (
SELECT id FROM orders
SELECT id FROM orders
WHERE created_at < '2023-01-01' AND status = 'completed'
)
DELETE FROM order_items
DELETE FROM order_items
WHERE order_id IN (SELECT id FROM old_orders)`,
want: `WITH old_orders AS (
SELECT id FROM orders
SELECT id FROM orders
WHERE created_at < ? AND status = ?
)
DELETE FROM order_items
DELETE FROM order_items
WHERE order_id IN (SELECT id FROM old_orders)`,
},
{
Expand Down Expand Up @@ -185,134 +185,6 @@ func TestPgSqlParser_Redact(t *testing.T) {
}
}

func TestPgSqlParser_ExtractTableNames(t *testing.T) {
tests := []struct {
name string
sql string
want []string
wantErr bool
}{
{
name: "simple select",
sql: "SELECT * FROM users",
want: []string{"users"},
},
{
name: "select with join",
sql: "SELECT * FROM users u JOIN orders o ON u.id = o.user_id",
want: []string{"orders", "users"},
},
{
name: "select with schema qualified tables",
sql: "SELECT * FROM public.users JOIN sales.orders ON users.id = orders.user_id",
want: []string{"public.users", "sales.orders"},
},
{
name: "insert statement",
sql: "INSERT INTO users (name, email) VALUES ('John', '[email protected]')",
want: []string{"users"},
},
{
name: "update statement",
sql: "UPDATE users SET last_login = NOW() WHERE id = 1",
want: []string{"users"},
},
{
name: "delete statement",
sql: "DELETE FROM users WHERE id = 1",
want: []string{"users"},
},
{
name: "with clause",
sql: `WITH active_users AS (
SELECT * FROM users WHERE status = 'active'
)
SELECT * FROM active_users au
JOIN orders o ON o.user_id = au.id`,
want: []string{"orders", "users"},
},
{
name: "subquery in where clause",
sql: `SELECT * FROM orders
WHERE user_id IN (SELECT id FROM users WHERE status = 'active')`,
want: []string{"orders", "users"},
},
{
name: "multiple schema qualified tables with aliases",
sql: `SELECT u.name, o.total, p.status
FROM public.users u
JOIN sales.orders o ON u.id = o.user_id
LEFT JOIN shipping.packages p ON o.id = p.order_id`,
want: []string{"public.users", "sales.orders", "shipping.packages"},
},
{
name: "truncated query with ...",
sql: "SELECT * FROM users JOIN orders ON users.id = orders.user_id AND...",
want: []string{"users", "orders"},
},
{
name: "truncated query with incomplete comment",
sql: "SELECT * FROM users JOIN orders ON users.id = orders.user_id /* some comment that gets truncated...",
want: []string{"users", "orders"},
},
{
name: "truncated query mid-table name",
sql: "SELECT * FROM users JOIN ord...",
want: []string{"users", "ord..."},
},
{
name: "truncated query with schema qualified tables",
sql: "SELECT * FROM public.users JOIN sales.orders ON users.id = orders.user_id AND...",
want: []string{"public.users", "sales.orders"},
},
{
name: "query with table.* expression",
sql: "SELECT u.*, o.* FROM users u JOIN orders o ON u.id = o.user_id",
want: []string{"users", "orders"},
},
{
name: "query with type cast",
sql: "SELECT u.id, '2024-03-20'::timestamp FROM users u",
want: []string{"users"},
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := ExtractTableNames(tt.sql)
if (err != nil) != tt.wantErr {
t.Errorf("ExtractTableNames() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !tt.wantErr {
if len(got) != len(tt.want) {
t.Errorf("ExtractTableNames()\nGOT = %v\nWANT = %v", got, tt.want)
return
}
// Compare slices ignoring order since table names might come in different order
gotMap := make(map[string]bool)
wantMap := make(map[string]bool)
for _, table := range got {
gotMap[table] = true
}
for _, table := range tt.want {
wantMap[table] = true
}
for table := range gotMap {
if !wantMap[table] {
t.Errorf("ExtractTableNames() got unexpected table = %v", table)
}
}
for table := range wantMap {
if !gotMap[table] {
t.Errorf("ExtractTableNames() missing expected table = %v", table)
}
}
}
})
}
}

func TestContainsReservedKeywords(t *testing.T) {
tests := []struct {
name string
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"strings"
"time"

"github.com/DataDog/go-sqllexer"
"github.com/go-kit/log"
"go.uber.org/atomic"

Expand Down Expand Up @@ -55,6 +56,7 @@ type QueryDetails struct {
collectInterval time.Duration
entryHandler loki.EntryHandler
tableRegistry *TableRegistry
normalizer *sqllexer.Normalizer

logger log.Logger
running *atomic.Bool
Expand All @@ -68,6 +70,7 @@ func NewQueryDetails(args QueryDetailsArguments) (*QueryDetails, error) {
collectInterval: args.CollectInterval,
entryHandler: args.EntryHandler,
tableRegistry: args.TableRegistry,
normalizer: sqllexer.NewNormalizer(sqllexer.WithCollectTables(true), sqllexer.WithCollectComments(true)),
logger: log.With(args.Logger, "collector", QueryDetailsCollector),
running: &atomic.Bool{},
}, nil
Expand Down Expand Up @@ -129,23 +132,25 @@ func (c QueryDetails) fetchAndAssociate(ctx context.Context) error {
for rs.Next() {
var queryID, queryText string
var databaseName database
err := rs.Scan(
&queryID,
&queryText,
&databaseName,
)
err := rs.Scan(&queryID, &queryText, &databaseName)
if err != nil {
level.Error(c.logger).Log("msg", "failed to scan result set for pg_stat_statements", "err", err)
continue
}

queryText, err = RemoveComments(c.normalizer, queryText)
if err != nil {
level.Error(c.logger).Log("msg", "failed to remove comments", "err", err)
continue
}

c.entryHandler.Chan() <- database_observability.BuildLokiEntry(
logging.LevelInfo,
OP_QUERY_ASSOCIATION,
fmt.Sprintf(`queryid="%s" querytext=%q datname="%s" engine="postgres"`, queryID, queryText, databaseName),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there a reason we're removing the db engine? I understand that postgres is uniquely identified by the other values, I.E. queryid and datname. That said, having the engine listed doesn't increase cardinality, and isn't a huge increase in log line size either.

I don't have a strong opinion either way, just curious what the motivation is here.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I understand that postgres is uniquely identified by the other values, I.E. queryid and datname. That said, having the engine listed doesn't increase cardinality, and isn't a huge increase in log line size either.

Yes that's correct, it doesn't really change the cardinality. It was added there at early development stage though and we've removed it from various places over time (this is likely the last place where it appears). Not a big deal, doing it just for consistency at this point.

fmt.Sprintf(`queryid="%s" querytext=%q datname="%s"`, queryID, queryText, databaseName),
)

tables, err := c.tryTokenizeTableNames(queryText)
tables, err := TokenizeTableNames(c.normalizer, queryText)
if err != nil {
level.Error(c.logger).Log("msg", "failed to tokenize table names", "err", err)
continue
Expand All @@ -160,7 +165,7 @@ func (c QueryDetails) fetchAndAssociate(ctx context.Context) error {
c.entryHandler.Chan() <- database_observability.BuildLokiEntry(
logging.LevelInfo,
OP_QUERY_PARSED_TABLE_NAME,
fmt.Sprintf(`queryid="%s" datname="%s" table="%s" engine="postgres" validated="%t"`, queryID, databaseName, table, validated),
fmt.Sprintf(`queryid="%s" datname="%s" table="%s" validated="%t"`, queryID, databaseName, table, validated),
)
}
}
Expand All @@ -173,12 +178,29 @@ func (c QueryDetails) fetchAndAssociate(ctx context.Context) error {
return nil
}

func (c QueryDetails) tryTokenizeTableNames(sqlText string) ([]string, error) {
func TokenizeTableNames(normalizer *sqllexer.Normalizer, sqlText string) ([]string, error) {
sqlText = strings.TrimSuffix(sqlText, "...")
tables, err := database_observability.ExtractTableNames(sqlText)
_, metadata, err := normalizer.Normalize(sqlText, sqllexer.WithDBMS(sqllexer.DBMSPostgres))
if err != nil {
return nil, fmt.Errorf("failed to tokenize table names: %w", err)
}

return metadata.Tables, nil
}

func RemoveComments(normalizer *sqllexer.Normalizer, sqlText string) (string, error) {
_, metadata, err := normalizer.Normalize(sqlText, sqllexer.WithDBMS(sqllexer.DBMSPostgres))
if err != nil {
return nil, fmt.Errorf("failed to extract table names: %w", err)
return sqlText, fmt.Errorf("failed to redact comments: %w", err)
}

if len(metadata.Comments) == 0 {
return sqlText, nil
}

for _, comment := range metadata.Comments {
sqlText = strings.ReplaceAll(sqlText, comment, "")
}

return tables, nil
return strings.TrimSpace(sqlText), nil
}
Loading
Loading