Skip to content
Closed
5 changes: 5 additions & 0 deletions checkers/checker.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (

goAnalysis "globstar.dev/analysis"
"globstar.dev/checkers/javascript"
"globstar.dev/checkers/python"
"globstar.dev/pkg/analysis"
)

Expand Down Expand Up @@ -69,6 +70,10 @@ var AnalyzerRegistry = []Analyzer{
TestDir: "checkers/javascript/testdata", // relative to the repository root
Analyzers: []*goAnalysis.Analyzer{javascript.NoDoubleEq, javascript.SQLInjection},
},
{
TestDir: "checkers/python/testdata", // relative to the repository root
Analyzers: []*goAnalysis.Analyzer{python.AvoidUnsanitizedSQL},
},
}

func LoadGoCheckers() []*goAnalysis.Analyzer {
Expand Down
257 changes: 257 additions & 0 deletions checkers/python/avoid-unsanitized-sql.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,257 @@
package python

import (
"fmt"
"path/filepath"
"strings"

sitter "github.com/smacker/go-tree-sitter"
"globstar.dev/analysis"
)

// AvoidUnsanitizedSQL is the Python security analyzer whose Run callback,
// checkSQLInjection, reports SQL execute calls that receive a query built from
// concatenated or interpolated strings.
var AvoidUnsanitizedSQL = &analysis.Analyzer{
	Run:         checkSQLInjection,
	Name:        "avoid-unsanitized-sql",
	Description: "Check if SQL query is sanitized",
	Language:    analysis.LangPy,
	Severity:    analysis.SeverityCritical,
	Category:    analysis.CategorySecurity,
}

// checkSQLInjection is the rule callback for AvoidUnsanitizedSQL. It visits
// every node of the current file in pre-order and, for each call whose callee
// is a recognized SQL execute method, inspects the first positional argument:
//
//   - if the argument itself is an unsafely built string, the call is reported;
//   - if the argument is a plain identifier, its assignments are traced to see
//     whether the value originates from an unsafely built string.
//
// It always returns (nil, nil); findings are delivered through pass.Report.
func checkSQLInjection(pass *analysis.Pass) (interface{}, error) {
	analysis.Preorder(pass, func(node *sitter.Node) {
		// Only process call nodes.
		if node.Type() != "call" {
			return
		}
		source := pass.FileContext.Source

		// Extract the function part (e.g. cursor.execute).
		functionNode := node.ChildByFieldName("function")
		if functionNode == nil {
			return
		}

		// Proceed only if the function is one of our recognized SQL methods.
		if !isSQLExecuteMethod(functionNode, source) {
			return
		}

		// The "arguments" field is normally an argument_list; anything else
		// (e.g. a bare generator expression) has no positional first argument.
		argsNode := node.ChildByFieldName("arguments")
		if argsNode == nil || argsNode.Type() != "argument_list" {
			return
		}

		// Child(0) of an argument_list is the anonymous "(" token, so the
		// first real argument must be looked up among the *named* children.
		// Comments are named nodes too and are skipped.
		var firstArg *sitter.Node
		for i := 0; i < int(argsNode.NamedChildCount()); i++ {
			if candidate := argsNode.NamedChild(i); candidate != nil && candidate.Type() != "comment" {
				firstArg = candidate
				break
			}
		}
		if firstArg == nil {
			return
		}

		// If the query string is built unsafely, report an issue.
		if isUnsafeString(firstArg, source) {
			pass.Report(pass, node, "Concatenated string in SQL query is an SQL injection threat")
			return
		}

		// If the argument is an identifier, trace its origin.
		if firstArg.Type() == "identifier" {
			varName := firstArg.Content(source)
			traceVariableOrigin(pass, varName, node, make(map[string]bool), make(map[string]bool), source)
		}
	})

	return nil, nil
}

// --- Helper Functions ---

// isSQLExecuteMethod reports whether node, the callee of a call expression,
// names one of the SQL execution methods: execute, executemany or
// executescript. For a bare identifier the identifier text is used; for an
// attribute access (obj.method) only the trailing attribute name is used.
// Any other node type yields false.
func isSQLExecuteMethod(node *sitter.Node, source []byte) bool {
	var name string
	if kind := node.Type(); kind == "identifier" {
		name = node.Content(source)
	} else if kind == "attribute" {
		if member := node.ChildByFieldName("attribute"); member != nil {
			name = member.Content(source)
		}
	}

	switch name {
	case "execute", "executemany", "executescript":
		return true
	}
	return false
}

// isUnsafeString reports whether node is a string expression assembled in a
// way that can splice untrusted data into a query:
//
//   - a string literal that has an interpolation child (an f-string);
//   - an implicit concatenation of adjacent literals where any part is unsafe;
//   - a "+" concatenation in which either operand contains a variable.
//
// A nil node is never unsafe, so callers may pass an absent capture directly.
func isUnsafeString(node *sitter.Node, source []byte) bool {
	if node == nil {
		return false
	}

	switch node.Type() {
	case "string", "fstring":
		// tree-sitter-python represents f-strings as "string" nodes whose
		// children include "interpolation"; "fstring" is kept for grammars
		// that use a dedicated node type.
		for i := 0; i < int(node.ChildCount()); i++ {
			if child := node.Child(i); child != nil && child.Type() == "interpolation" {
				return true
			}
		}

	case "concatenated_string":
		// Adjacent literals ("a" f"{b}") are unsafe if any piece is.
		for i := 0; i < int(node.ChildCount()); i++ {
			if isUnsafeString(node.Child(i), source) {
				return true
			}
		}

	case "binary_operator":
		op := node.ChildByFieldName("operator")
		if op != nil && op.Content(source) == "+" {
			return containsVariable(node.ChildByFieldName("left"), source) ||
				containsVariable(node.ChildByFieldName("right"), source)
		}
	}

	return false
}

// traceVariableOrigin looks for the origin of varName, first among the
// assignments of the current file and, if that search does not settle the
// question, among the files it is imported from. visitedVars records the names
// already traced so that a chain of assignments cannot loop forever; a name
// that is already recorded is not traced again.
func traceVariableOrigin(pass *analysis.Pass, varName string, originalNode *sitter.Node,
	visitedVars map[string]bool, visitedFiles map[string]bool, source []byte) {

	if alreadySeen := visitedVars[varName]; alreadySeen {
		return
	}
	visitedVars[varName] = true

	// Only fall through to imports when the local search found nothing.
	if !traceLocalAssignments(pass, varName, originalNode, visitedVars, visitedFiles, source) {
		traceCrossFileImports(pass, varName, originalNode, visitedVars, visitedFiles, source)
	}
}

// traceLocalAssignments scans every `identifier = value` assignment in the
// file held by pass.FileContext for an assignment to varName.
//
// For the first matching assignment whose value is an unsafely built string,
// the issue is reported against originalNode and true is returned. If the
// value is another identifier, tracing continues with that identifier and true
// is returned. Assignments with any other kind of value are skipped. It
// returns false when no assignment settles the trace, or when the query cannot
// be compiled (the signature offers no way to surface that error).
//
// The search is file-wide and does not take Python scoping into account.
func traceLocalAssignments(pass *analysis.Pass, varName string, originalNode *sitter.Node,
	visitedVars map[string]bool, visitedFiles map[string]bool, source []byte) bool {

	query := `(assignment left: (identifier) @var right: (_) @value)`
	q, err := sitter.NewQuery([]byte(query), pass.Analyzer.Language.Grammar())
	if err != nil {
		return false
	}
	defer q.Close()

	cursor := sitter.NewQueryCursor()
	defer cursor.Close()
	cursor.Exec(q, pass.FileContext.Ast)

	for {
		match, ok := cursor.NextMatch()
		if !ok {
			return false
		}

		var varNode, valueNode *sitter.Node
		for _, capture := range match.Captures {
			// The capture's own Index identifies its name; its position in
			// the Captures slice is not guaranteed to be the capture id.
			switch q.CaptureNameForId(capture.Index) {
			case "var":
				varNode = capture.Node
			case "value":
				valueNode = capture.Node
			}
		}

		// Both sides are needed; skip incomplete matches and other variables.
		if varNode == nil || valueNode == nil || varNode.Content(source) != varName {
			continue
		}

		if isUnsafeString(valueNode, source) {
			pass.Report(pass, originalNode, fmt.Sprintf("Variable '%s' originates from an unsafe string", varName))
			return true
		}

		if valueNode.Type() == "identifier" {
			// Follow `a = b` chains; visitedVars stops cycles.
			traceVariableOrigin(pass, valueNode.Content(source), originalNode, visitedVars, visitedFiles, source)
			return true
		}
	}
}

// traceCrossFileImports handles the case where varName is brought into the
// current file by a `from module import varName` statement. For each such
// import it converts the module name to a relative .py path, finds the files
// in pass.Files whose path ends with it, and continues the trace inside those
// files, reporting against originalNode.
//
// visitedFiles records module paths already examined so that circular imports
// terminate. Query compilation errors end the search silently because the
// signature has no error return.
func traceCrossFileImports(pass *analysis.Pass, varName string, originalNode *sitter.Node,
	visitedVars map[string]bool, visitedFiles map[string]bool, source []byte) {

	query := `(
(import_from_statement
module_name: (dotted_name) @module
name: (dotted_name) @imported_var
) @import
)`
	q, err := sitter.NewQuery([]byte(query), pass.Analyzer.Language.Grammar())
	if err != nil {
		return
	}
	defer q.Close()

	cursor := sitter.NewQueryCursor()
	defer cursor.Close()
	cursor.Exec(q, pass.FileContext.Ast)

	for {
		match, ok := cursor.NextMatch()
		if !ok {
			break
		}

		var moduleNode, varNode *sitter.Node
		for _, capture := range match.Captures {
			// Resolve the name through capture.Index: the enclosing @import
			// capture means slice position and capture id do not line up.
			switch q.CaptureNameForId(capture.Index) {
			case "module":
				moduleNode = capture.Node
			case "imported_var":
				varNode = capture.Node
			}
		}

		if varNode == nil || moduleNode == nil || varNode.Content(source) != varName {
			continue
		}

		modulePath := convertImportToPath(moduleNode.Content(source))
		if visitedFiles[modulePath] {
			continue
		}
		visitedFiles[modulePath] = true

		for _, file := range pass.Files {
			if !strings.HasSuffix(file.FilePath, modulePath) {
				continue
			}

			// Create a temporary analyzer context for the imported file.
			tempPass := &analysis.Pass{
				Analyzer:    pass.Analyzer,
				FileContext: file,
				Files:       pass.Files,
				Report:      pass.Report, // Reuse the report function.
			}

			// varName is already marked in visitedVars by the caller, so going
			// back through traceVariableOrigin would return immediately and the
			// imported file would never be inspected. Search it directly;
			// visitedFiles guards against import cycles.
			if !traceLocalAssignments(tempPass, varName, originalNode, visitedVars, visitedFiles, file.Source) {
				traceCrossFileImports(tempPass, varName, originalNode, visitedVars, visitedFiles, file.Source)
			}
		}
	}
}

// containsVariable reports whether the expression rooted at node references a
// variable: an identifier or attribute access, either directly or nested
// inside binary operators and parentheses. A nil node and every other node
// type yield false.
func containsVariable(node *sitter.Node, source []byte) bool {
	if node == nil {
		return false
	}

	kind := node.Type()
	if kind == "identifier" || kind == "attribute" {
		return true
	}
	if kind == "binary_operator" {
		if containsVariable(node.ChildByFieldName("left"), source) {
			return true
		}
		return containsVariable(node.ChildByFieldName("right"), source)
	}
	if kind == "parenthesized_expression" {
		// Look through the parentheses at the wrapped expression.
		return containsVariable(node.NamedChild(0), source)
	}
	return false
}

// getNthChild returns the n-th child of node, counting anonymous tokens as
// well as named nodes, or nil when n is not below the child count.
func getNthChild(node *sitter.Node, n int) *sitter.Node {
	if total := int(node.ChildCount()); n >= total {
		return nil
	}
	return node.Child(n)
}

// convertImportToPath turns a dotted Python module name such as "pkg.mod" into
// a relative file path ("pkg/mod.py" with the platform's separator) by
// replacing every dot with the path separator and appending ".py".
func convertImportToPath(importStr string) string {
	segments := strings.Split(importStr, ".")
	return strings.Join(segments, string(filepath.Separator)) + ".py"
}
44 changes: 44 additions & 0 deletions checkers/python/avoid-unsanitized-sql.test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@

import sqlite3
from fastapi import FastAPI, Query
import sqlite3

app = FastAPI()

def execute_unsafe_query(query: str):
    """Fixture: execute ``query`` verbatim, with no bound parameters.

    Opens ``test.db``, runs the query through ``cursor.execute``, and returns
    the fetched rows after committing and closing the connection. The
    ``execute`` call is the line the checker is expected to flag.
    """
    conn = sqlite3.connect("test.db")
    cursor = conn.cursor()
    #<expect-error>
    cursor.execute(query) #unsafe with user input
    result = cursor.fetchall()
    conn.commit()
    conn.close()
    return result

def better_query(query: str, params):
    """Fixture: execute ``query`` with ``params`` passed separately.

    Opens ``test.db``, runs the query with ``params`` as the second argument
    to ``cursor.execute``, and returns the fetched rows after committing and
    closing the connection. No error is expected on this call.
    """
    conn = sqlite3.connect("test.db")
    cursor = conn.cursor()
    cursor.execute(query, params) #safe to execute with user input
    result = cursor.fetchall()
    conn.commit()
    conn.close()
    return result


@app.get("/unsafe_query/")
def unsafe_query(user_input: str):
    """Fixture endpoint that builds three queries from ``user_input``.

    Two queries embed ``user_input`` directly in the SQL text (f-string and
    ``+`` concatenation) and are run through ``execute_unsafe_query``; the
    third uses a ``?`` placeholder and is run through ``better_query``.
    """
    # f-string case: user_input is interpolated into the SQL text

    query = f"SELECT * FROM users WHERE name = {user_input}"
    # binary operator case: user_input is concatenated onto the SQL text

    query2= "SELECT * FROM users WHERE name ="+ user_input

    # placeholder query: should not be identified as an error
    query3= "SELECT * FROM user WHERE name= ?"
    result = execute_unsafe_query(query)
    result2= execute_unsafe_query(query=query2)

    result3= better_query(query=query3, params=(user_input,))

    return {"result": result, "result2": result2, "result3": result3}
42 changes: 42 additions & 0 deletions checkers/python/testdata/avoid-unsanitized-sql.test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
import sqlite3
from fastapi import FastAPI, Query
import sqlite3

app = FastAPI()

def execute_unsafe_query(query: str):
    """Fixture: execute ``query`` verbatim, with no bound parameters.

    Opens ``test.db``, runs the query through ``cursor.execute``, and returns
    the fetched rows after committing and closing the connection. The
    ``execute`` call is the line the checker is expected to flag.
    """
    conn = sqlite3.connect("test.db")
    cursor = conn.cursor()
    #<expect-error>
    cursor.execute(query) #unsafe with user input
    result = cursor.fetchall()
    conn.commit()
    conn.close()
    return result

def better_query(query: str, params):
    """Fixture: execute ``query`` with ``params`` passed separately.

    Opens ``test.db``, runs the query with ``params`` as the second argument
    to ``cursor.execute``, and returns the fetched rows after committing and
    closing the connection. No error is expected on this call.
    """
    conn = sqlite3.connect("test.db")
    cursor = conn.cursor()
    cursor.execute(query, params) #safe to execute with user input
    result = cursor.fetchall()
    conn.commit()
    conn.close()
    return result


@app.get("/unsafe_query/")
def unsafe_query(user_input: str):
    """Fixture endpoint that builds three queries from ``user_input``.

    Two queries embed ``user_input`` directly in the SQL text (f-string and
    ``+`` concatenation) and are run through ``execute_unsafe_query``; the
    third uses a ``?`` placeholder and is run through ``better_query``.
    """
    # f-string case: user_input is interpolated into the SQL text

    query = f"SELECT * FROM users WHERE name = {user_input}"
    # binary operator case: user_input is concatenated onto the SQL text

    query2= "SELECT * FROM users WHERE name ="+ user_input
    # placeholder query: should not be identified as an error
    query3= "SELECT * FROM user WHERE name= ?"
    result = execute_unsafe_query(query)
    result2= execute_unsafe_query(query=query2)

    result3= better_query(query=query3, params=(user_input,))

    return {"result": result, "result2": result2, "result3": result3}