Skip to content

Commit 40652a6

Browse files
authored
SONARPY-1711: Rule S5659: do not raise under certain use of get_unverified_header() (#2041)
1 parent 696732f commit 40652a6

File tree

3 files changed

+202
-3
lines changed

3 files changed

+202
-3
lines changed

python-checks/src/main/java/org/sonar/python/checks/JwtVerificationCheck.java

Lines changed: 103 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,21 +19,30 @@
1919
*/
2020
package org.sonar.python.checks;
2121

22+
import java.util.Collection;
23+
import java.util.List;
2224
import java.util.Optional;
2325
import java.util.Set;
26+
import java.util.stream.Stream;
2427
import javax.annotation.Nullable;
2528
import org.sonar.check.Rule;
2629
import org.sonar.plugins.python.api.PythonSubscriptionCheck;
2730
import org.sonar.plugins.python.api.SubscriptionContext;
2831
import org.sonar.plugins.python.api.symbols.Symbol;
32+
import org.sonar.plugins.python.api.symbols.Usage;
33+
import org.sonar.plugins.python.api.tree.Argument;
34+
import org.sonar.plugins.python.api.tree.AssignmentStatement;
2935
import org.sonar.plugins.python.api.tree.CallExpression;
3036
import org.sonar.plugins.python.api.tree.DictionaryLiteral;
3137
import org.sonar.plugins.python.api.tree.Expression;
38+
import org.sonar.plugins.python.api.tree.ExpressionList;
3239
import org.sonar.plugins.python.api.tree.KeyValuePair;
3340
import org.sonar.plugins.python.api.tree.ListLiteral;
3441
import org.sonar.plugins.python.api.tree.Name;
42+
import org.sonar.plugins.python.api.tree.QualifiedExpression;
3543
import org.sonar.plugins.python.api.tree.RegularArgument;
3644
import org.sonar.plugins.python.api.tree.StringLiteral;
45+
import org.sonar.plugins.python.api.tree.SubscriptionExpression;
3746
import org.sonar.plugins.python.api.tree.Tree;
3847
import org.sonar.plugins.python.api.tree.Tree.Kind;
3948
import org.sonar.plugins.python.api.tree.Tuple;
@@ -55,6 +64,8 @@ public class JwtVerificationCheck extends PythonSubscriptionCheck {
5564
"python_jwt.verify_jwt",
5665
"jwt.verify_jwt");
5766

67+
private static final Set<String> ALLOWED_KEYS_ACCESS = Set.of("jku", "jwk", "kid", "x5u", "x5c", "x5t", "xt#256");
68+
5869
private static final Set<String> WHERE_VERIFY_KWARG_SHOULD_BE_TRUE_FQNS = Set.of(
5970
"jwt.decode",
6071
"jose.jws.verify");
@@ -66,8 +77,7 @@ public class JwtVerificationCheck extends PythonSubscriptionCheck {
6677
"jose.jws.get_unverified_header",
6778
"jose.jws.get_unverified_headers",
6879
"jose.jwt.get_unverified_claims",
69-
"jose.jws.get_unverified_claims"
70-
);
80+
"jose.jws.get_unverified_claims");
7181

7282
private static final String VERIFY_SIGNATURE_KEYWORD = "verify_signature";
7383

@@ -97,7 +107,7 @@ private static void verifyCallExpression(SubscriptionContext ctx) {
97107
Optional.ofNullable(TreeUtils.firstAncestorOfKind(call, Kind.FILE_INPUT, Kind.FUNCDEF))
98108
.filter(scriptOrFunction -> !TreeUtils.hasDescendant(scriptOrFunction, JwtVerificationCheck::isCallToVerifyJwt))
99109
.ifPresent(scriptOrFunction -> ctx.addIssue(call, MESSAGE));
100-
} else if (UNVERIFIED_FQNS.contains(calleeFqn)) {
110+
} else if (UNVERIFIED_FQNS.contains(calleeFqn) && !accessOnlyAllowedHeaderKeys(call)) {
101111
Optional.ofNullable(TreeUtils.nthArgumentOrKeyword(0, "", call.arguments()))
102112
.flatMap(TreeUtils.toOptionalInstanceOfMapper(RegularArgument.class))
103113
.map(RegularArgument::expression)
@@ -175,4 +185,94 @@ private static boolean isCallToVerifyJwt(Tree t) {
175185
.isPresent();
176186
}
177187

188+
private static boolean accessOnlyAllowedHeaderKeys(CallExpression call) {
189+
Tree assignment = TreeUtils.firstAncestorOfKind(call, Tree.Kind.ASSIGNMENT_STMT);
190+
Stream<StringLiteral> headerKeysAccessedDirectly = accessToHeaderKeyWithoutAssignment(call);
191+
if (assignment == null) {
192+
return areStringLiteralsPartOfAllowedKeys(headerKeysAccessedDirectly);
193+
} else {
194+
List<Expression> lhsExpressions = ((AssignmentStatement) assignment).lhsExpressions().stream()
195+
.map(ExpressionList::expressions)
196+
.flatMap(Collection::stream).toList();
197+
if (lhsExpressions.size() == 1 && lhsExpressions.get(0).is(Tree.Kind.NAME)) {
198+
Name name = (Name) lhsExpressions.get(0);
199+
Symbol symbol = name.symbol();
200+
if (symbol != null) {
201+
Stream<StringLiteral> argumentsOfGet = usagesAccessedByGet(symbol, call);
202+
Stream<StringLiteral> argumentsOfSubscription = usagesAccessedBySubscription(symbol, call);
203+
Stream<StringLiteral> headerKeysAccessFromAssignedValues = Stream.concat(argumentsOfGet, argumentsOfSubscription);
204+
return areStringLiteralsPartOfAllowedKeys(Stream.concat(headerKeysAccessFromAssignedValues, headerKeysAccessedDirectly));
205+
}
206+
}
207+
}
208+
return false;
209+
}
210+
211+
private static boolean areStringLiteralsPartOfAllowedKeys(Stream<StringLiteral> literals) {
212+
var literalList = literals.toList();
213+
return !literalList.isEmpty() && literalList.stream().allMatch(str -> ALLOWED_KEYS_ACCESS.contains(str.trimmedQuotesValue()));
214+
}
215+
216+
private static Stream<StringLiteral> accessToHeaderKeyWithoutAssignment(CallExpression call) {
217+
Stream<CallExpression> callExpressionFromGetUnverifiedHeaders = getCallExprWhereDictIsAccessedWithGet(Stream.of(call.parent()));
218+
Stream<Argument> argumentsOfCallExpr = getArgumentsFromCallExpr(callExpressionFromGetUnverifiedHeaders);
219+
Stream<StringLiteral> stringLiteralArgumentsFromGet = getStringLiteralArguments(argumentsOfCallExpr);
220+
Stream<SubscriptionExpression> subscriptionFromGetUnverifiedHeaders = getSubscriptions(Stream.of(call.parent()));
221+
Stream<StringLiteral> stringLiteralArgumentFromSubscription = getSubscriptsStringLiteral(subscriptionFromGetUnverifiedHeaders);
222+
return Stream.concat(stringLiteralArgumentsFromGet, stringLiteralArgumentFromSubscription);
223+
}
224+
225+
private static Stream<StringLiteral> usagesAccessedByGet(Symbol symbol, CallExpression call) {
226+
var usages = getForwardUsages(symbol, call);
227+
var parentOfUsages = usages.map(Usage::tree).map(Tree::parent);
228+
var callExpressionsFromUsages = getCallExprWhereDictIsAccessedWithGet(parentOfUsages);
229+
return getStringLiteralArguments(getArgumentsFromCallExpr(callExpressionsFromUsages));
230+
}
231+
232+
private static Stream<Argument> getArgumentsFromCallExpr(Stream<CallExpression> callExprs) {
233+
return callExprs.map(CallExpression::arguments).flatMap(Collection::stream);
234+
}
235+
236+
private static Stream<Usage> getForwardUsages(Symbol symbol, CallExpression call) {
237+
return symbol.usages().stream()
238+
.filter(usage -> usage.tree().firstToken().line() > call.callee().firstToken().line());
239+
}
240+
241+
private static Stream<CallExpression> getCallExprWhereDictIsAccessedWithGet(Stream<Tree> parentQualifiedExpr) {
242+
return parentQualifiedExpr
243+
.filter(parent -> parent.is(Tree.Kind.QUALIFIED_EXPR))
244+
.flatMap(TreeUtils.toStreamInstanceOfMapper(QualifiedExpression.class))
245+
.filter(expr -> "get".equals(expr.name().name()))
246+
.filter(expr -> expr.parent().is(Kind.CALL_EXPR))
247+
.map(QualifiedExpression::parent)
248+
.flatMap(TreeUtils.toStreamInstanceOfMapper(CallExpression.class));
249+
}
250+
251+
private static Stream<StringLiteral> getStringLiteralArguments(Stream<Argument> arguments) {
252+
return arguments.filter(arg -> arg.is(Tree.Kind.REGULAR_ARGUMENT))
253+
.flatMap(TreeUtils.toStreamInstanceOfMapper(RegularArgument.class))
254+
.map(RegularArgument::expression)
255+
.flatMap(TreeUtils.toStreamInstanceOfMapper(StringLiteral.class));
256+
}
257+
258+
private static Stream<StringLiteral> usagesAccessedBySubscription(Symbol symbol, CallExpression call) {
259+
var usages = getForwardUsages(symbol, call);
260+
var parentFromUsages = usages.map(Usage::tree).map(Tree::parent);
261+
var subscriptionsFromUsages = getSubscriptions(parentFromUsages);
262+
return getSubscriptsStringLiteral(subscriptionsFromUsages);
263+
}
264+
265+
private static Stream<SubscriptionExpression> getSubscriptions(Stream<Tree> subscriptions) {
266+
return subscriptions
267+
.filter(subscription -> subscription.is(Tree.Kind.SUBSCRIPTION))
268+
.flatMap(TreeUtils.toStreamInstanceOfMapper(SubscriptionExpression.class));
269+
}
270+
271+
private static Stream<StringLiteral> getSubscriptsStringLiteral(Stream<SubscriptionExpression> subscriptions) {
272+
return subscriptions.map(SubscriptionExpression::subscripts)
273+
.map(ExpressionList::expressions)
274+
.flatMap(Collection::stream)
275+
.flatMap(TreeUtils.toStreamInstanceOfMapper(StringLiteral.class));
276+
}
277+
178278
}

python-checks/src/test/resources/checks/jwtVerification.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,3 +42,53 @@ def pyjwt_decode_token_secure_2(token):
4242

4343
def pyjwt_decode_unverified_header(token):
4444
return jwt.get_unverified_header(token) # Noncompliant
45+
46+
def get_unverified_header_access(token:str):
47+
header = jwt.get_unverified_header(token) # Noncompliant
48+
print(f"Extra data in header: {header['extra']}")
49+
50+
def get_unverified_header_return(token: str) -> Dict[str, str]:
51+
header = jwt.get_unverified_header(token) # Noncompliant
52+
return header
53+
54+
def get_unverified_header_non_compliant_sanity_check(token: str, some_object, other_call) -> Dict[str, str]:
55+
other = header = jwt.get_unverified_header(token) # Noncompliant
56+
other.get("kid")
57+
58+
some_object[0] = jwt.get_unverified_header(token) # Noncompliant
59+
header = jwt.get_unverified_header(token) # Noncompliant
60+
61+
def get_unverified_header_sanity_checks(token: str , other_call) -> Dict[str, str]:
62+
header = jwt.get_unverified_header(token) # Noncompliant
63+
header.test("kid")
64+
header.get
65+
header.get()
66+
header[slice(12)]
67+
other_call(jwt.get_unverified_header(token).get("x5u"))
68+
return jwt.get_unverified_header(token).get("kid")
69+
70+
def get_unverified_header_used(token: str, do_other_things_with):
71+
header = jwt.get_unverified_header(token) # Noncompliant
72+
return do_other_things_with(header)
73+
74+
def get_unverified_header_disallowed_access(token: str):
75+
header = jwt.get_unverified_header(token) # Noncompliant
76+
kid = header.get("kid")
77+
not_kid = header.get("extra")
78+
79+
header = jwt.get_unverified_header(token) # Noncompliant
80+
kid = header.get("kid")
81+
not_kid = header["extra"]
82+
83+
def get_unverified_header_compliant(token: str, keys):
84+
header = jwt.get_unverified_header(token) # Compliant: only "kid" is accessed
85+
kid = header.get("kid")
86+
87+
x5u = jwt.get_unverified_header(token).get("x5u") # Compliant
88+
89+
x5t = jwt.get_unverified_header(token)["x5t"] # Compliant
90+
header = jwt.get_unverified_header(token) # Compliant
91+
jku = header["jku"]
92+
key = keys[jku]
93+
claims = jwt.decode(token, key, algorithms=["HS256"])
94+
return claims

python-checks/src/test/resources/checks/pythonJoseVerification.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -221,3 +221,52 @@ async def compliant_verify(token: str | None = None):
221221
class TestInfiniteRecursion():
222222
x = x
223223
payload = jwt.decode(token, None, options=x)
224+
225+
def pyjwt_decode_unverified_header(token):
226+
return jwt.get_unverified_header(token) # Noncompliant
227+
228+
def get_unverified_header_access(token:str):
229+
header = jwt.get_unverified_header(token) # Noncompliant
230+
print(f"Extra data in header: {header['extra']}")
231+
232+
def get_unverified_header_return(token: str) -> Dict[str, str]:
233+
header = jwt.get_unverified_header(token) # Noncompliant
234+
return header
235+
236+
def get_unverified_header_sanity_checks(token: str, some_object) -> Dict[str, str]:
237+
other = header = jwt.get_unverified_header(token) # Noncompliant
238+
other.get("kid")
239+
240+
some_object[0] = jwt.get_unverified_header(token) # Noncompliant
241+
header = jwt.get_unverified_header(token) # Noncompliant
242+
header.test("kid")
243+
header.get
244+
header.get()
245+
header[slice(12)]
246+
return header
247+
248+
def get_unverified_header_used(token: str, do_other_things_with):
249+
header = jwt.get_unverified_header(token) # Noncompliant
250+
return do_other_things_with(header)
251+
252+
def get_unverified_header_disallowed_access(token: str):
253+
header = jwt.get_unverified_header(token) # Noncompliant
254+
kid = header.get("kid")
255+
not_kid = header.get("extra")
256+
257+
header = jwt.get_unverified_header(token) # Noncompliant
258+
kid = header.get("kid")
259+
not_kid = header["extra"]
260+
261+
def get_unverified_header_compliant(token: str, keys):
262+
header = jwt.get_unverified_header(token) # Compliant: only "kid" is accessed
263+
kid = header.get("kid")
264+
265+
x5u = jwt.get_unverified_header(token).get("x5u") # Compliant
266+
267+
x5t = jwt.get_unverified_header(token)["x5t"] # Compliant
268+
header = jwt.get_unverified_header(token) # Compliant
269+
jku = header["jku"]
270+
key = keys[jku]
271+
claims = jwt.decode(token, key, algorithms=["HS256"])
272+
return claims

0 commit comments

Comments
 (0)