diff --git a/patchwork/steps/PR/typed.py b/patchwork/steps/PR/typed.py index cec3e3f2a..fc543e3ef 100644 --- a/patchwork/steps/PR/typed.py +++ b/patchwork/steps/PR/typed.py @@ -31,6 +31,7 @@ class PRInputs(TypedDict, total=False): scm_url: Annotated[str, StepTypeConfig(is_config=True)] gitlab_api_key: Annotated[str, StepTypeConfig(is_config=True)] github_api_key: Annotated[str, StepTypeConfig(is_config=True)] + issue_url: Annotated[str, StepTypeConfig(is_config=True)] class PROutputs(TypedDict): diff --git a/patchwork/steps/PreparePR/PreparePR.py b/patchwork/steps/PreparePR/PreparePR.py index 288557d2f..6203fb3d8 100644 --- a/patchwork/steps/PreparePR/PreparePR.py +++ b/patchwork/steps/PreparePR/PreparePR.py @@ -3,23 +3,23 @@ from patchwork.logger import logger from patchwork.step import Step, StepStatus +from patchwork.steps.PreparePR.typed import PreparePRInputs, PreparePROutputs -class PreparePR(Step): - required_keys = {"modified_code_files"} +class PreparePR(Step, input_class=PreparePRInputs, output_class=PreparePROutputs): def __init__(self, inputs: dict): super().__init__(inputs) - if not all(key in inputs.keys() for key in self.required_keys): - raise ValueError(f'Missing required data: "{self.required_keys}"') - - if len(inputs["modified_code_files"]) < 1: - logger.warning("No modified files to prepare a PR for.") self.modified_code_files = inputs["modified_code_files"] + if len(self.modified_code_files) < 1: + logger.warning("No modified files to prepare a PR for.") - self.header = f"This pull request from patched fixes {len(self.modified_code_files)} issues." - if "pr_header" in inputs.keys(): - self.header = inputs["pr_header"] + issue_url = inputs.get("issue_url") + self.header = inputs.get("pr_header") + if self.header is None and issue_url is None: + self.header = f"This pull request from patched fixes {len(self.modified_code_files)} issues." + elif self.header is None and issue_url is not None: + self.header = f"This pull request from patched fixes [issue]({issue_url})." def run(self) -> dict: if len(self.modified_code_files) == 0: diff --git a/patchwork/steps/PreparePR/typed.py b/patchwork/steps/PreparePR/typed.py index eda835196..c60a3d66e 100644 --- a/patchwork/steps/PreparePR/typed.py +++ b/patchwork/steps/PreparePR/typed.py @@ -9,6 +9,7 @@ class __PreparePRRequiredInputs(TypedDict): class PreparePRInputs(__PreparePRRequiredInputs, total=False): pr_header: Annotated[str, StepTypeConfig(is_config=True)] + issue_url: Annotated[str, StepTypeConfig(is_config=True)] class PreparePROutputs(TypedDict): diff --git a/tests/steps/test_PreparePR.py b/tests/steps/test_PreparePR.py index 864d2cd4e..c53bc745a 100644 --- a/tests/steps/test_PreparePR.py +++ b/tests/steps/test_PreparePR.py @@ -15,10 +15,6 @@ def prepare_pr_instance(): return PreparePR(inputs) -def test_init_required_keys(prepare_pr_instance): - assert prepare_pr_instance.required_keys == {"modified_code_files"} - - def test_init_inputs(prepare_pr_instance): assert prepare_pr_instance.modified_code_files == [ {"path": "file1", "start_line": 1, "end_line": 2, "commit_message": "commit msg"},