@@ -245,6 +245,25 @@ private Optional<MethodCallExpr> validateExecuteCall(final MethodCallExpr execut
245245 return Optional .of (stmtObject );
246246 }
247247
248+ private Optional <Either <AssignExpr , LocalVariableDeclaration >>
249+ validateStatementCreationExprForHijack (
250+ final Either <MethodCallExpr , Either <AssignExpr , LocalVariableDeclaration >> stmtObject ) {
251+ if (stmtObject .isRight ()) {
252+ var maybelvd =
253+ stmtObject
254+ .getRight ()
255+ .ifLeftOrElseGet (
256+ ae ->
257+ ASTs .findEarliestLocalVariableDeclarationOf (
258+ ae , ae .getTarget ().asNameExpr ().getNameAsString ()),
259+ lvd -> Optional .of (lvd ));
260+ if (maybelvd .filter (lvd -> lvd instanceof ExpressionStmtVariableDeclaration ).isPresent ()) {
261+ return Optional .of (stmtObject .getRight ());
262+ }
263+ }
264+ return Optional .empty ();
265+ }
266+
248267 /** Checks if a local declaration can change types to a subtype. */
249268 private boolean canChangeTypes (final LocalVariableDeclaration localDeclaration ) {
250269 final var allNameExpr =
@@ -642,27 +661,68 @@ private boolean assignedOrDefinedInScope(
642661 return assignedInScope || definedInScope ;
643662 }
644663
645- private MethodCallExpr fixHijackedStatement (
664+ private Expression getConnectionExpression (
665+ final Either <AssignExpr , LocalVariableDeclaration > stmtCreation ) {
666+ return stmtCreation
667+ .ifLeftOrElseGet (
668+ ae -> ae .getValue ().asMethodCallExpr (),
669+ lvd -> lvd .getDeclaration ().getInitializer ().get ().asMethodCallExpr ())
670+ .getScope ()
671+ .get ();
672+ }
673+
674+ private MethodCallExpr fixByHijackedStatement (
646675 final Either <AssignExpr , LocalVariableDeclaration > stmtCreation ,
647676 final QueryParameterizer queryParameterizer ,
648677 final MethodCallExpr executeCall ) {
649678 var executeStmt = ASTs .findParentStatementFrom (executeCall ).get ();
650- // TODO this shouldn't work on anything but expression statements declaration, filter them out
651- // create new PreparedStatement object and set parameters
652-
653679 // get the statement object variable name
654680 final String stmtName =
655681 stmtCreation .ifLeftOrElseGet (
656682 a -> a .getTarget ().asNameExpr ().getNameAsString (), LocalVariableDeclaration ::getName );
683+ // generate a name for the new PreparedStatement object
684+ String pStmtName = generateNameWithSuffix (executeCall );
685+
686+ final String connName = getConnectionExpression (stmtCreation ).asNameExpr ().getNameAsString ();
687+
688+ var topStatement = executeStmt ;
657689
658690 // Replace the parameters with the `?` string and adds the `setParameter` calls
659691 // Also, get the top `setParameter` statement
660- var topStatement = gatherAndSetParameters (stmtName , executeStmt , queryParameterizer );
692+ topStatement = gatherAndSetParameters (pStmtName , topStatement , queryParameterizer );
661693
662- // .close the original statement and assign it to the stmt variable
694+ // Add PreparedStmt stmt = conn.prepareStatement() assignment
695+ MethodCallExpr prepareStatementCall =
696+ new MethodCallExpr (new NameExpr (connName ), "prepareStatement" , executeCall .getArguments ());
697+ ExpressionStmt pStmtCreation =
698+ new ExpressionStmt (
699+ new VariableDeclarationExpr (
700+ new VariableDeclarator (
701+ StaticJavaParser .parseType ("PreparedStatement" ),
702+ pStmtName ,
703+ prepareStatementCall )));
704+ ASTTransforms .addStatementBeforeStatement (topStatement , pStmtCreation );
705+ topStatement = pStmtCreation ;
706+
707+ // add stmt.close()
708+ Statement closeOriginal =
709+ new ExpressionStmt (new MethodCallExpr (new NameExpr (stmtName ), new SimpleName ("close" )));
710+ ASTTransforms .addStatementBeforeStatement (topStatement , closeOriginal );
711+
712+ // TODO will this work for every type of execute statement? or just executeQuery?
663713 // change execute statement
664- // TODO will this work for any type of execute statement?
665- return null ;
714+ executeCall .setName ("execute" );
715+ executeCall .setScope (new NameExpr (pStmtName ));
716+ executeCall .setArguments (new NodeList <>());
717+
718+ // add stmt = pstmt after executeCall
719+ Statement hijackAssignment =
720+ new ExpressionStmt (
721+ new AssignExpr (
722+ new NameExpr (stmtName ), new NameExpr (pStmtName ), AssignExpr .Operator .ASSIGN ));
723+ ASTTransforms .addStatementAfterStatement (executeStmt , hijackAssignment );
724+
725+ return prepareStatementCall ;
666726 }
667727
668728 /**
@@ -678,8 +738,7 @@ public Optional<MethodCallExpr> checkAndFix() {
678738 // validate the call itself first
679739 if (isParameterizationCandidate (executeCall ) && validateExecuteCall (executeCall ).isPresent ()) {
680740 // Now find the stmt creation expression, if any and validate it
681- final var stmtObject =
682- findStatementCreationExpr (executeCall ).flatMap (this ::validateStatementCreationExpr );
741+ final var stmtObject = findStatementCreationExpr (executeCall );
683742
684743 if (stmtObject .isPresent ()) {
685744 // Now look for injections
@@ -713,16 +772,21 @@ public Optional<MethodCallExpr> checkAndFix() {
713772 .map (Expression ::asNameExpr )
714773 .anyMatch (name -> assignedOrDefinedInScope (name , assignOrLVD )));
715774
716- if (queryp .getInjections ().isEmpty () || resolvedInScope ) {
775+ // No injections detected
776+ if (queryp .getInjections ().isEmpty ()) {
717777 return Optional .empty ();
718778 }
719779
720- if (nameInScope ) {
780+ // This means we can replace the Statement declaration or assignment
781+ if (!nameInScope
782+ && !resolvedInScope
783+ && stmtObject .flatMap (this ::validateStatementCreationExpr ).isPresent ()) {
721784 return Optional .of (fix (stmtObject .get (), queryp , executeCall ));
722785 }
723- if (stmtObject .get ().isRight ()) {
724- return Optional .of (
725- fixHijackedStatement (stmtObject .get ().getRight (), queryp , executeCall ));
786+ // Otherwise we use the hijack strategy
787+ var maybeStmtObject = stmtObject .flatMap (this ::validateStatementCreationExprForHijack );
788+ if (maybeStmtObject .isPresent ()) {
789+ return Optional .of (fixByHijackedStatement (maybeStmtObject .get (), queryp , executeCall ));
726790 }
727791 }
728792 }
0 commit comments