diff --git a/release.py b/release.py index 8bbfbabc..8652c23b 100755 --- a/release.py +++ b/release.py @@ -60,6 +60,11 @@ def get(self, key: Literal["auth_info"], default: str | None = None) -> str: ... @overload def get(self, key: Literal["ssh_user"], default: str | None = None) -> str: ... + @overload + def get( + self, key: Literal["ssh_key"], default: str | None = None + ) -> str | None: ... + @overload def get(self, key: Literal["sign_gpg"], default: bool | None = None) -> bool: ... @@ -84,6 +89,9 @@ def __getitem__(self, key: Literal["auth_info"]) -> str: ... @overload def __getitem__(self, key: Literal["ssh_user"]) -> str: ... + @overload + def __getitem__(self, key: Literal["ssh_key"]) -> str | None: ... + @overload def __getitem__(self, key: Literal["sign_gpg"]) -> bool: ... @@ -110,6 +118,9 @@ def __setitem__(self, key: Literal["auth_info"], value: str) -> None: ... @overload def __setitem__(self, key: Literal["ssh_user"], value: str) -> None: ... + @overload + def __setitem__(self, key: Literal["ssh_key"], value: str | None) -> None: ... + @overload def __setitem__(self, key: Literal["sign_gpg"], value: bool) -> None: ... diff --git a/run_release.py b/run_release.py index c4cf0cb6..74b43965 100755 --- a/run_release.py +++ b/run_release.py @@ -221,6 +221,7 @@ def __init__( api_key: str, ssh_user: str, sign_gpg: bool, + ssh_key: str | None = None, first_state: Task | None = None, ) -> None: self.tasks = tasks @@ -243,6 +244,8 @@ def __init__( self.db["auth_info"] = api_key if not self.db.get("ssh_user"): self.db["ssh_user"] = ssh_user + if not self.db.get("ssh_key"): + self.db["ssh_key"] = ssh_key if not self.db.get("sign_gpg"): self.db["sign_gpg"] = sign_gpg @@ -255,6 +258,7 @@ def __init__( print(f"- Normalized release tag: {release_tag.normalized()}") print(f"- Git repo: {self.db['git_repo']}") print(f"- SSH username: {self.db['ssh_user']}") + print(f"- SSH key: {self.db['ssh_key'] or 'Default'}") print(f"- Sign with GPG: {self.db['sign_gpg']}") print() @@ -340,9 +344,13 @@ def check_ssh_connection(db: ReleaseShelf) -> None: client = paramiko.SSHClient() client.load_system_host_keys() client.set_missing_host_key_policy(paramiko.WarningPolicy) - client.connect(DOWNLOADS_SERVER, port=22, username=db["ssh_user"]) + client.connect( + DOWNLOADS_SERVER, port=22, username=db["ssh_user"], key_filename=db["ssh_key"] + ) client.exec_command("pwd") - client.connect(DOCS_SERVER, port=22, username=db["ssh_user"]) + client.connect( + DOCS_SERVER, port=22, username=db["ssh_user"], key_filename=db["ssh_key"] + ) client.exec_command("pwd") @@ -350,7 +358,9 @@ def check_sigstore_client(db: ReleaseShelf) -> None: client = paramiko.SSHClient() client.load_system_host_keys() client.set_missing_host_key_policy(paramiko.WarningPolicy) - client.connect(DOWNLOADS_SERVER, port=22, username=db["ssh_user"]) + client.connect( + DOWNLOADS_SERVER, port=22, username=db["ssh_user"], key_filename=db["ssh_key"] + ) _, stdout, _ = client.exec_command("python3 -m sigstore --version") sigstore_version = stdout.read(1000).decode() sigstore_vermatch = re.match("^sigstore ([0-9.]+)", sigstore_version) @@ -659,7 +669,7 @@ def sign_source_artifacts(db: ReleaseShelf) -> None: subprocess.check_call( [ - "python3", + sys.executable, "-m", "sigstore", "sign", @@ -730,7 +740,7 @@ def upload_files_to_server(db: ReleaseShelf, server: str) -> None: client = paramiko.SSHClient() client.load_system_host_keys() client.set_missing_host_key_policy(paramiko.WarningPolicy) - client.connect(server, port=22, username=db["ssh_user"]) + client.connect(server, port=22, username=db["ssh_user"], key_filename=db["ssh_key"]) transport = client.get_transport() assert transport is not None, f"SSH transport to {server} is None" @@ -775,7 +785,9 @@ def place_files_in_download_folder(db: ReleaseShelf) -> None: client = paramiko.SSHClient() client.load_system_host_keys() client.set_missing_host_key_policy(paramiko.WarningPolicy) - client.connect(DOWNLOADS_SERVER, port=22, username=db["ssh_user"]) + client.connect( + DOWNLOADS_SERVER, port=22, username=db["ssh_user"], key_filename=db["ssh_key"] + ) transport = client.get_transport() assert transport is not None, f"SSH transport to {DOWNLOADS_SERVER} is None" @@ -826,7 +838,9 @@ def unpack_docs_in_the_docs_server(db: ReleaseShelf) -> None: client = paramiko.SSHClient() client.load_system_host_keys() client.set_missing_host_key_policy(paramiko.WarningPolicy) - client.connect(DOCS_SERVER, port=22, username=db["ssh_user"]) + client.connect( + DOCS_SERVER, port=22, username=db["ssh_user"], key_filename=db["ssh_key"] + ) transport = client.get_transport() assert transport is not None, f"SSH transport to {DOCS_SERVER} is None" @@ -968,7 +982,9 @@ def wait_until_all_files_are_in_folder(db: ReleaseShelf) -> None: client = paramiko.SSHClient() client.load_system_host_keys() client.set_missing_host_key_policy(paramiko.WarningPolicy) - client.connect(DOWNLOADS_SERVER, port=22, username=db["ssh_user"]) + client.connect( + DOWNLOADS_SERVER, port=22, username=db["ssh_user"], key_filename=db["ssh_key"] + ) ftp_client = client.open_sftp() destination = f"/srv/www.python.org/ftp/python/{db['release'].normalized()}" @@ -1006,7 +1022,9 @@ def run_add_to_python_dot_org(db: ReleaseShelf) -> None: client = paramiko.SSHClient() client.load_system_host_keys() client.set_missing_host_key_policy(paramiko.WarningPolicy) - client.connect(DOWNLOADS_SERVER, port=22, username=db["ssh_user"]) + client.connect( + DOWNLOADS_SERVER, port=22, username=db["ssh_user"], key_filename=db["ssh_key"] + ) transport = client.get_transport() assert transport is not None, f"SSH transport to {DOWNLOADS_SERVER} is None" @@ -1344,6 +1362,13 @@ def _api_key(api_key: str) -> str: help="Username to be used when authenticating via ssh", type=str, ) + parser.add_argument( + "--ssh-key", + dest="ssh_key", + default=None, + help="Path to the SSH key file to use for authentication", + type=str, + ) args = parser.parse_args() auth_key = args.auth_key or os.getenv("AUTH_INFO") @@ -1432,6 +1457,7 @@ def _api_key(api_key: str) -> str: api_key=auth_key, ssh_user=args.ssh_user, sign_gpg=not no_gpg, + ssh_key=args.ssh_key, tasks=tasks, ) automata.run()