Skip to content

Commit e943f3c

Browse files
committed
Refactored hardening and remediation behavior
1 parent c73ba1b commit e943f3c

File tree

2 files changed

+104
-69
lines changed

2 files changed

+104
-69
lines changed

src/codemodder/codemods/base_codemod.py

Lines changed: 95 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -185,12 +185,98 @@ def get_files_to_analyze(
185185
"""
186186
...
187187

188-
def _apply(
188+
def _apply_remediation(
189189
self,
190190
context: CodemodExecutionContext,
191191
rules: list[str],
192-
remediation: bool,
193192
) -> None | TokenUsage:
193+
"""
194+
Applies remediation behavior to a codemod, that is, each changeset will only be associated with a single finding and no files will be written.
195+
"""
196+
results: ResultSet | None = self._apply_detector(context)
197+
198+
if results is not None and not results:
199+
logger.debug("No results for %s", self.id)
200+
return None
201+
202+
if not (files_to_analyze := self.get_files_to_analyze(context, results)):
203+
logger.debug("No files matched for %s", self.id)
204+
return None
205+
206+
# Process each result independently and output the diffs
207+
# gather positional arguments for the map
208+
resultset_arguments: list[ResultSet | None] = []
209+
path_arguments = []
210+
if results:
211+
for result in results.results_for_rules(rules):
212+
# this needs to be the same type of ResultSet as results
213+
singleton = results.__class__()
214+
singleton.add_result(result)
215+
result_locations = self.get_files_to_analyze(context, singleton)
216+
# We do an execution for each location in the result
217+
# So we duplicate the resultset argument for each location
218+
for loc in result_locations:
219+
resultset_arguments.append(singleton)
220+
path_arguments.append(loc)
221+
# An exception for find-and-fix codemods
222+
else:
223+
resultset_arguments = [None]
224+
path_arguments = files_to_analyze
225+
226+
contexts: list = []
227+
with ThreadPoolExecutor() as executor:
228+
logger.debug("using executor with %s workers", context.max_workers)
229+
contexts.extend(
230+
executor.map(
231+
lambda path, resultset: self._process_file(
232+
path, context, resultset, rules
233+
),
234+
path_arguments,
235+
resultset_arguments or [None],
236+
)
237+
)
238+
executor.shutdown(wait=True)
239+
240+
context.process_results(self.id, contexts)
241+
return None
242+
243+
def _apply_hardening(
244+
self,
245+
context: CodemodExecutionContext,
246+
rules: list[str],
247+
) -> None | TokenUsage:
248+
"""
249+
Applies hardening behavior to a codemod with the goal of integrating all fixes for each finding into the files.
250+
"""
251+
results: ResultSet | None = self._apply_detector(context)
252+
253+
if results is not None and not results:
254+
logger.debug("No results for %s", self.id)
255+
return None
256+
257+
if not (files_to_analyze := self.get_files_to_analyze(context, results)):
258+
logger.debug("No files matched for %s", self.id)
259+
return None
260+
261+
# Hardens all findings per file at once and writes the fixed code into the file
262+
process_file = functools.partial(
263+
self._process_file, context=context, results=results, rules=rules
264+
)
265+
266+
contexts = []
267+
if context.max_workers == 1:
268+
logger.debug("processing files serially")
269+
contexts.extend([process_file(file) for file in files_to_analyze])
270+
else:
271+
with ThreadPoolExecutor() as executor:
272+
logger.debug("using executor with %s workers", context.max_workers)
273+
contexts.extend(executor.map(process_file, files_to_analyze))
274+
executor.shutdown(wait=True)
275+
276+
context.process_results(self.id, contexts)
277+
return None
278+
279+
def _apply_detector(self, context: CodemodExecutionContext) -> ResultSet | None:
194280
if self.provider and (
195281
not (provider := context.providers.get_provider(self.provider))
196282
or not provider.is_available
@@ -219,68 +305,7 @@ def _apply(
219305
else None
220306
)
221307

222-
if results is not None and not results:
223-
logger.debug("No results for %s", self.id)
224-
return None
225-
226-
if not (files_to_analyze := self.get_files_to_analyze(context, results)):
227-
logger.debug("No files matched for %s", self.id)
228-
return None
229-
230-
# Do each result independently and outputs the diffs
231-
if remediation:
232-
# gather positional arguments for the map
233-
resultset_arguments: list[ResultSet | None] = []
234-
path_arguments = []
235-
if results:
236-
for result in results.results_for_rules(rules):
237-
# this need to be the same type of ResultSet as results
238-
singleton = results.__class__()
239-
singleton.add_result(result)
240-
result_locations = self.get_files_to_analyze(context, singleton)
241-
# We do an execution for each location in the result
242-
# So we duplicate the resultset argument for each location
243-
for loc in result_locations:
244-
resultset_arguments.append(singleton)
245-
path_arguments.append(loc)
246-
# An exception for find-and-fix codemods
247-
else:
248-
resultset_arguments = [None]
249-
path_arguments = files_to_analyze
250-
251-
contexts: list = []
252-
with ThreadPoolExecutor() as executor:
253-
logger.debug("using executor with %s workers", context.max_workers)
254-
contexts.extend(
255-
executor.map(
256-
lambda path, resultset: self._process_file(
257-
path, context, resultset, rules
258-
),
259-
path_arguments,
260-
resultset_arguments or [None],
261-
)
262-
)
263-
executor.shutdown(wait=True)
264-
265-
context.process_results(self.id, contexts)
266-
# Hardens all findings per file at once and writes the fixed code into the file
267-
else:
268-
process_file = functools.partial(
269-
self._process_file, context=context, results=results, rules=rules
270-
)
271-
272-
contexts = []
273-
if context.max_workers == 1:
274-
logger.debug("processing files serially")
275-
contexts.extend([process_file(file) for file in files_to_analyze])
276-
else:
277-
with ThreadPoolExecutor() as executor:
278-
logger.debug("using executor with %s workers", context.max_workers)
279-
contexts.extend(executor.map(process_file, files_to_analyze))
280-
executor.shutdown(wait=True)
281-
282-
context.process_results(self.id, contexts)
283-
return None
308+
return results
284309

285310
def apply(
286311
self, context: CodemodExecutionContext, remediation: bool = False
@@ -300,7 +325,9 @@ def apply(
300325
301326
:param context: The codemod execution context
302327
"""
303-
return self._apply(context, [self._internal_name], remediation)
328+
if remediation:
329+
return self._apply_remediation(context, [self._internal_name])
330+
return self._apply_hardening(context, [self._internal_name])
304331

305332
def _process_file(
306333
self,
@@ -401,7 +428,9 @@ def __init__(
401428
def apply(
402429
self, context: CodemodExecutionContext, remediation: bool = False
403430
) -> None | TokenUsage:
404-
return self._apply(context, self.requested_rules, remediation)
431+
if remediation:
432+
return self._apply_remediation(context, self.requested_rules)
433+
return self._apply_hardening(context, self.requested_rules)
405434

406435
def get_files_to_analyze(
407436
self,

src/core_codemods/defectdojo/api.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from codemodder.codemods.api import Metadata, Reference, ToolMetadata, ToolRule
99
from codemodder.codemods.base_detector import BaseDetector
1010
from codemodder.context import CodemodExecutionContext
11+
from codemodder.llm import TokenUsage
1112
from codemodder.result import ResultSet
1213
from core_codemods.api import CoreCodemod, SASTCodemod
1314

@@ -77,10 +78,15 @@ def apply(
7778
self,
7879
context: CodemodExecutionContext,
7980
remediation: bool = False,
80-
) -> None:
81-
self._apply(
81+
) -> None | TokenUsage:
82+
if remediation:
83+
return self._apply_remediation(
84+
context,
85+
# We know this has a tool because we created it with `from_core_codemod`
86+
cast(ToolMetadata, self._metadata.tool).rule_ids,
87+
)
88+
return self._apply_hardening(
8289
context,
8390
# We know this has a tool because we created it with `from_core_codemod`
8491
cast(ToolMetadata, self._metadata.tool).rule_ids,
85-
remediation,
8692
)

0 commit comments

Comments
 (0)