11import libcst as cst
2+ from libcst .codemod import CodemodContext
23
34from codemodder .codemods .base_codemod import Metadata , ReviewGuidance , ToolRule
45from codemodder .codemods .libcst_transformer import (
78)
89from codemodder .codemods .utils_mixin import NameAndAncestorResolutionMixin
910from codemodder .codetf import Reference
11+ from codemodder .file_context import FileContext
12+ from codemodder .result import Result
1013from core_codemods .sonar .api import SonarCodemod
1114
1215rules = [
@@ -23,28 +26,55 @@ class SonarUseSecureProtocolsTransformer(
2326):
2427 change_description = "Modified URLs or calls to use secure protocols"
2528
29+ def __init__ (
30+ self ,
31+ context : CodemodContext ,
32+ results : list [Result ] | None ,
33+ file_context : FileContext ,
34+ _transformer : bool = False ,
35+ ):
36+ self .nodes_memory_with_context_name : dict [cst .CSTNode , str ] = {}
37+ super ().__init__ (context , results , file_context , _transformer )
38+
2639 def _match_and_handle_statement (
2740 self , possible_smtp_call , original_node_statement , updated_node_statement
2841 ):
2942 maybe_name = self .find_base_name (possible_smtp_call )
3043 match possible_smtp_call :
3144 case cst .Call () if maybe_name == "smtplib.SMTP" :
45+ # get the stored context_name or create a new one:
46+ if possible_smtp_call in self .nodes_memory_with_context_name :
47+ context_name = self .nodes_memory_with_context_name [
48+ possible_smtp_call
49+ ]
50+ else :
51+ context_name = self .generate_available_name (
52+ original_node_statement , ["smtp_context" ]
53+ )
54+
3255 new_statements = []
3356 new_statements .append (
34- cst .parse_statement ("smtp_context = ssl.create_default_context()" )
57+ cst .parse_statement (
58+ f"{ context_name } = ssl.create_default_context()"
59+ )
3560 )
3661 new_statements .append (
37- cst .parse_statement ("smtp_context.verify_mode = ssl.CERT_REQUIRED" )
62+ cst .parse_statement (
63+ f"{ context_name } .verify_mode = ssl.CERT_REQUIRED"
64+ )
3865 )
3966 new_statements .append (
40- cst .parse_statement ("smtp_context .check_hostname = True" )
67+ cst .parse_statement (f" { context_name } .check_hostname = True" )
4168 )
4269 new_statements .append (updated_node_statement )
43- # TODO don't append this if we changed the call to SSL version
44- new_statements .append (
45- cst .parse_statement ("smtp.starttls(context=smtp_context)" )
46- )
47- self .add_needed_import ("smtp" )
70+ # don't append this if we changed the call to SSL version
71+ if possible_smtp_call in self .nodes_memory_with_context_name :
72+ self .nodes_memory_with_context_name .pop (possible_smtp_call )
73+ else :
74+ new_statements .append (
75+ cst .parse_statement (f"smtplib.starttls(context={ context_name } )" )
76+ )
77+ self .add_needed_import ("smtplib" )
4878 self .add_needed_import ("ssl" )
4979 self .report_change (possible_smtp_call )
5080 return cst .FlattenSentinel (new_statements )
@@ -70,6 +100,8 @@ def leave_Call(self, original_node, updated_node):
70100 self .report_change (original_node )
71101 self .add_needed_import ("ftplib" )
72102 return updated_node .with_changes (func = new_func )
103+ # Just using ssl.create_default_context() may not be enough for older python versions
104+ # See https://stackoverflow.com/questions/33857698/sending-email-from-python-using-starttls
73105 case "smtplib.SMTP" :
74106 # port is the second positional, check that
75107 maybe_port_value = (
@@ -94,37 +126,28 @@ def leave_Call(self, original_node, updated_node):
94126 )
95127 match maybe_port_value :
96128 case None :
97- new_func = cst .parse_expression ("smtplib.SMTP_SSL" )
98- self .report_change (original_node )
99- self .add_needed_import ("smtplib" )
100- new_args = [
101- * original_node .args ,
102- cst .Arg (
103- keyword = cst .Name ("context" ),
104- value = cst .Name ("smtp_context" ),
105- ),
106- ]
107- return updated_node .with_changes (
108- func = new_func , args = new_args
109- )
110- # TODO still needs the context object statements here
111- # TODO only change this if it mathces the statement pattern in leave_statemenet
112- # TODO use a flag for this in visit_SimpleStatement
129+ self ._change_to_smtp_ssl (original_node , updated_node )
113130 case cst .Integer () if maybe_port_value == "0" :
114- new_func = cst .parse_expression ("smtplib.SMTP_SSL" )
115- self .report_change (original_node )
116- self .add_needed_import ("smtplib" )
117- new_args = [
118- * original_node .args ,
119- cst .Arg (
120- keyword = cst .Name ("context" ),
121- value = cst .Name ("smtp_context" ),
122- ),
123- ]
124- return updated_node .with_changes (func = new_func )
131+ self ._change_to_smtp_ssl (original_node , updated_node )
125132
126133 return updated_node
127134
135+ def _change_to_smtp_ssl (self , original_node , updated_node ):
136+ # remember this node so we don't add the starttls
137+ new_func = cst .parse_expression ("smtplib.SMTP_SSL" )
138+
139+ context_name = self .generate_available_name (original_node , ["smtp_context" ])
140+ self .nodes_memory_with_context_name [original_node ] = context_name
141+
142+ new_args = [
143+ * original_node .args ,
144+ cst .Arg (
145+ keyword = cst .Name ("context" ),
146+ value = cst .Name (context_name ),
147+ ),
148+ ]
149+ return updated_node .with_changes (func = new_func , args = new_args )
150+
128151 def leave_SimpleString (
129152 self , original_node : cst .SimpleString , updated_node : cst .SimpleString
130153 ) -> cst .BaseExpression :
0 commit comments