1818import io .codemodder .remediation .FixCandidate ;
1919import io .codemodder .remediation .FixCandidateSearchResults ;
2020import io .codemodder .remediation .FixCandidateSearcher ;
21+ import io .codemodder .remediation .MethodOrConstructor ;
2122import io .github .pixee .security .HostValidator ;
2223import io .github .pixee .security .Urls ;
2324import java .util .ArrayList ;
2425import java .util .List ;
26+ import java .util .function .BiPredicate ;
2527import java .util .function .Function ;
28+ import org .javatuples .Pair ;
2629
2730final class DefaultSSRFRemediator implements SSRFRemediator {
2831
@@ -45,8 +48,25 @@ public <T> CodemodFileScanningResult remediateAll(
4548 .withMatcher (mce -> !mce .getArguments ().isEmpty ())
4649 .build ();
4750
48- FixCandidateSearchResults <T > results =
49- searcher .search (
51+ // Matches calls like RestTemplate.exchange(...)
52+ // Doesn't actually check that the `exchange` call scope is actually of type RestTemplate
53+ // This is left for the detectors, for now
54+ FixCandidateSearcher <T > rtSearcher =
55+ new FixCandidateSearcher .Builder <T >()
56+ // is method with name
57+ .withMatcher (mce -> mce .isMethodCallWithName ("exchange" ))
58+ // has a scope
59+ .withMatcher (MethodOrConstructor ::isMethodCallWithScope )
60+ // Could be improved further by adding a RestTemplate type check to the scope
61+ .build ();
62+
63+ List <CodemodChange > changes = new ArrayList <>();
64+ List <UnfixedFinding > unfixedFindings = new ArrayList <>();
65+
66+ var pairResult =
67+ searchAndFix (
68+ searcher ,
69+ (cunit , moc ) -> harden (cunit , moc .asObjectCreationExpr ()),
5070 cu ,
5171 path ,
5272 detectorRule ,
@@ -55,30 +75,28 @@ public <T> CodemodFileScanningResult remediateAll(
5575 getStartLine ,
5676 getEndLine ,
5777 getStartColumn );
78+ changes .addAll (pairResult .getValue0 ());
79+ unfixedFindings .addAll (pairResult .getValue1 ());
5880
59- List <CodemodChange > changes = new ArrayList <>();
60-
61- for (FixCandidate <T > candidate : results .fixCandidates ()) {
62- ObjectCreationExpr call = (ObjectCreationExpr ) candidate .call ().asNode ();
63- List <T > issues = candidate .issues ();
64- harden (cu , call );
65- List <FixedFinding > fixedFindings =
66- issues .stream ()
67- .map (issue -> new FixedFinding (getKey .apply (issue ), detectorRule ))
68- .toList ();
69- CodemodChange change =
70- CodemodChange .from (
71- getStartLine .apply (issues .get (0 )),
72- List .of (DependencyGAV .JAVA_SECURITY_TOOLKIT ),
73- fixedFindings );
74- changes .add (change );
75- }
81+ var pairResultRT =
82+ searchAndFix (
83+ rtSearcher ,
84+ (cunit , moc ) -> hardenRT (cunit , moc .asMethodCall ()),
85+ cu ,
86+ path ,
87+ detectorRule ,
88+ issuesForFile ,
89+ getKey ,
90+ getStartLine ,
91+ getEndLine ,
92+ getStartColumn );
93+ changes .addAll (pairResultRT .getValue0 ());
94+ unfixedFindings .addAll (pairResultRT .getValue1 ());
7695
77- List <UnfixedFinding > unfixedFindings = new ArrayList <>(results .unfixableFindings ());
7896 return CodemodFileScanningResult .from (changes , unfixedFindings );
7997 }
8098
81- private void harden (final CompilationUnit cu , final ObjectCreationExpr newUrlCall ) {
99+ private boolean harden (final CompilationUnit cu , final ObjectCreationExpr newUrlCall ) {
82100 NodeList <Expression > arguments = newUrlCall .getArguments ();
83101
84102 /*
@@ -92,23 +110,102 @@ private void harden(final CompilationUnit cu, final ObjectCreationExpr newUrlCal
92110 * ...
93111 * URL u = Urls.create(foo, io.github.pixee.security.Urls.HTTP_PROTOCOLS, io.github.pixee.security.HostValidator.ALLOW_ALL)
94112 */
113+ MethodCallExpr safeCall = wrapInUrlsCreate (cu , arguments );
114+ newUrlCall .replace (safeCall );
115+ return true ;
116+ }
117+
118+ private MethodCallExpr wrapInUrlsCreate (
119+ final CompilationUnit cu , final NodeList <Expression > arguments ) {
95120 addImportIfMissing (cu , Urls .class .getName ());
96121 addImportIfMissing (cu , HostValidator .class .getName ());
122+
97123 FieldAccessExpr httpProtocolsExpr = new FieldAccessExpr ();
98124 httpProtocolsExpr .setScope (new NameExpr (Urls .class .getSimpleName ()));
99125 httpProtocolsExpr .setName ("HTTP_PROTOCOLS" );
100126
101127 FieldAccessExpr denyCommonTargetsExpr = new FieldAccessExpr ();
102-
103128 denyCommonTargetsExpr .setScope (new NameExpr (HostValidator .class .getSimpleName ()));
104129 denyCommonTargetsExpr .setName ("DENY_COMMON_INFRASTRUCTURE_TARGETS" );
105130
106131 NodeList <Expression > newArguments = new NodeList <>();
107- newArguments .addAll (arguments ); // first are all the arguments they were passing to "new URL"
132+ newArguments .addAll (arguments ); // add expression
108133 newArguments .add (httpProtocolsExpr ); // load the protocols they're allowed
109134 newArguments .add (denyCommonTargetsExpr ); // load the host validator
110- MethodCallExpr safeCall =
111- new MethodCallExpr (new NameExpr (Urls .class .getSimpleName ()), "create" , newArguments );
112- newUrlCall .replace (safeCall );
135+
136+ return new MethodCallExpr (new NameExpr (Urls .class .getSimpleName ()), "create" , newArguments );
137+ }
138+
139+ private boolean hardenRT (final CompilationUnit cu , final MethodCallExpr call ) {
140+ var maybeFirstArg = call .getArguments ().stream ().findFirst ();
141+ if (maybeFirstArg .isPresent ()) {
142+ var wrappedArg =
143+ new MethodCallExpr (
144+ wrapInUrlsCreate (cu , new NodeList <>(maybeFirstArg .get ().clone ())), "toString" );
145+ maybeFirstArg .get ().replace (wrappedArg );
146+ return true ;
147+ }
148+ return false ;
149+ }
150+
151+ /**
152+ * Returns a list of changes and unfixed findings for a pair of searcher, that gather relevant
153+ * issues, and a fixer predicate, that returns true if the change is successful.
154+ */
155+ private <T > Pair <List <CodemodChange >, List <UnfixedFinding >> searchAndFix (
156+ final FixCandidateSearcher <T > searcher ,
157+ final BiPredicate <CompilationUnit , MethodOrConstructor > fixer ,
158+ final CompilationUnit cu ,
159+ final String path ,
160+ final DetectorRule detectorRule ,
161+ final List <T > issuesForFile ,
162+ final Function <T , String > getKey ,
163+ final Function <T , Integer > getStartLine ,
164+ final Function <T , Integer > getEndLine ,
165+ final Function <T , Integer > getStartColumn ) {
166+ List <CodemodChange > changes = new ArrayList <>();
167+ List <UnfixedFinding > unfixedFindings = new ArrayList <>();
168+
169+ FixCandidateSearchResults <T > results =
170+ searcher .search (
171+ cu ,
172+ path ,
173+ detectorRule ,
174+ issuesForFile ,
175+ getKey ,
176+ getStartLine ,
177+ getEndLine ,
178+ getStartColumn );
179+
180+ for (FixCandidate <T > candidate : results .fixCandidates ()) {
181+ MethodOrConstructor call = candidate .call ();
182+ List <T > issues = candidate .issues ();
183+ if (fixer .test (cu , call )) {
184+ List <FixedFinding > fixedFindings =
185+ issues .stream ()
186+ .map (issue -> new FixedFinding (getKey .apply (issue ), detectorRule ))
187+ .toList ();
188+ CodemodChange change =
189+ CodemodChange .from (
190+ getStartLine .apply (issues .get (0 )),
191+ List .of (DependencyGAV .JAVA_SECURITY_TOOLKIT ),
192+ fixedFindings );
193+ changes .add (change );
194+ } else {
195+ issues .forEach (
196+ issue -> {
197+ final String id = getKey .apply (issue );
198+ final UnfixedFinding unfixableFinding =
199+ new UnfixedFinding (
200+ id ,
201+ detectorRule ,
202+ path ,
203+ getStartLine .apply (issues .get (0 )),
204+ "State changing effects possible or unrecognized code shape" );
205+ unfixedFindings .add (unfixableFinding );
206+ });
207+ }
208+ }
209+ return Pair .with (changes , unfixedFindings );
113210 }
114211}
0 commit comments