diff --git a/Cargo.lock b/Cargo.lock index 00a340ae..4ad1da47 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1474,7 +1474,6 @@ dependencies = [ "git-url-parse", "git2", "hub_client", - "openssh", "progress_tracking", "reqwest", "reqwest-middleware", @@ -1482,6 +1481,7 @@ dependencies = [ "serde", "serde_json", "serial_test", + "shell-words", "tempfile", "thiserror 2.0.12", "tokio", @@ -2542,20 +2542,6 @@ version = "0.1.11" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b4ce411919553d3f9fa53a0880544cda985a112117a0444d5ff1e870a893d6ea" -[[package]] -name = "openssh" -version = "0.11.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ea0bb128ba90e86bc55dae66031935f361cda4cbc1f011547c55a7d80079bc3e" -dependencies = [ - "libc", - "once_cell", - "shell-escape", - "tempfile", - "thiserror 2.0.12", - "tokio", -] - [[package]] name = "openssl" version = "0.10.72" @@ -3615,10 +3601,10 @@ dependencies = [ ] [[package]] -name = "shell-escape" -version = "0.1.5" +name = "shell-words" +version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "45bb67a18fa91266cc7807181f62f9178a6873bfad7dc788c42e6430db40184f" +checksum = "24188a676b6ae68c3b2cb3a01be17fbf7240ce009799bb56d5b1409051e78fde" [[package]] name = "shellexpand" diff --git a/Cargo.toml b/Cargo.toml index a6fef7f3..e9f3146f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -71,7 +71,6 @@ mockall = "0.13" more-asserts = "0.3" once_cell = "1.20" oneshot = "0.1" -openssh = "0.11" pin-project = "1" prometheus = "0.14" rand = "0.9" @@ -93,6 +92,7 @@ serde = { version = "1", features = ["derive"] } serde_json = "1" serde_repr = "0.1" sha2 = "0.10" +shell-words = "1.1" shellexpand = "3.1" static_assertions = "1.1" sysinfo = "0.37" diff --git a/git_xet/Cargo.toml b/git_xet/Cargo.toml index 10c6f333..b23ffbfa 100644 --- a/git_xet/Cargo.toml +++ b/git_xet/Cargo.toml @@ -26,12 +26,10 @@ reqwest-middleware = { workspace = true } rust-netrc = { workspace = true } serde = { workspace = true } serde_json = { workspace = true } +shell-words = { workspace = true } tempfile = { workspace = true } thiserror = { workspace = true } tokio = { workspace = true } -[target.'cfg(unix)'.dependencies] -openssh = { workspace = true } - [dev-dependencies] serial_test = { workspace = true } \ No newline at end of file diff --git a/git_xet/src/auth.rs b/git_xet/src/auth.rs index c8f88a46..3847e8f4 100644 --- a/git_xet/src/auth.rs +++ b/git_xet/src/auth.rs @@ -127,7 +127,7 @@ pub fn get_credential(repo: &GitRepo, remote_url: &GitUrl, operation: Operation) // 5. check remote URL scheme if matches!(remote_url.scheme(), Scheme::Ssh | Scheme::GitSsh) { #[cfg(unix)] - return Ok(SSHCredentialHelper::new(remote_url, operation)); + return Ok(SSHCredentialHelper::new(remote_url, repo, operation)); #[cfg(not(unix))] return Err(GitXetError::not_supported(format!( "using {} in a repository with SSH Git URL is under development; please check back for diff --git a/git_xet/src/auth/ssh.rs b/git_xet/src/auth/ssh.rs index e672248f..ca8631c2 100644 --- a/git_xet/src/auth/ssh.rs +++ b/git_xet/src/auth/ssh.rs @@ -1,14 +1,16 @@ use std::sync::Arc; use async_trait::async_trait; -use hub_client::{CredentialHelper, HubClientError, Operation, Result}; -#[cfg(unix)] -use openssh::{KnownHosts, Session}; +use hub_client::{CredentialHelper, Operation}; use reqwest::header; use reqwest_middleware::RequestBuilder; use serde::Deserialize; +use crate::errors::{GitXetError, Result}; +use crate::git_repo::GitRepo; use crate::git_url::GitUrl; +use crate::utils::process_wrapping::run_program_captured_with_input_and_output; +use crate::utils::ssh_connect::{SSHMetadata, get_sshcmd_and_args}; #[derive(Deserialize)] struct GitLFSAuthentationResponseHeader { @@ -32,39 +34,38 @@ struct GitLFSAuthenticateResponse { // it has a shorter TTL than that of a Xet CAS JWT. pub struct SSHCredentialHelper { remote_url: GitUrl, + repo: GitRepo, operation: Operation, } impl SSHCredentialHelper { - pub fn new(remote_url: &GitUrl, operation: Operation) -> Arc { + pub fn new(remote_url: &GitUrl, repo: &GitRepo, operation: Operation) -> Arc { Arc::new(Self { remote_url: remote_url.clone(), + repo: repo.clone(), operation, }) } - #[cfg(unix)] async fn authenticate(&self) -> Result { - let host_url = self.remote_url.host_url().map_err(HubClientError::credential_helper_error)?; - let full_repo_path = self.remote_url.full_repo_path(); - let session = Session::connect(&host_url, KnownHosts::Add) - .await - .map_err(HubClientError::credential_helper_error)?; - - let output = session - .command("git-lfs-authenticate") - .arg(full_repo_path) - .arg(self.operation.as_str()) - .output() - .await - .map_err(HubClientError::credential_helper_error)?; - - serde_json::from_slice(&output.stdout).map_err(HubClientError::credential_helper_error) - } + let meta = SSHMetadata { + user_and_host: self.remote_url.user_and_host()?, + port: self.remote_url.port(), + arg_list: vec![ + "git-lfs-authenticate".into(), + self.remote_url.full_repo_path(), + self.operation.as_str().into(), + ], + }; - #[cfg(not(unix))] - async fn authenticate(&self) -> Result { - unimplemented!() + let (program, args) = get_sshcmd_and_args(&meta, &self.repo)?; + + let (output, _err) = + run_program_captured_with_input_and_output(program, self.repo.git_path()?, args)?.wait_with_output()?; + + let response: GitLFSAuthenticateResponse = serde_json::from_slice(&output).map_err(GitXetError::internal)?; + + Ok(response) } } @@ -86,14 +87,18 @@ mod tests { use hub_client::Operation; use super::SSHCredentialHelper; + use crate::git_repo::GitRepo; use crate::git_url::GitUrl; + use crate::test_utils::TestRepo; #[tokio::test] #[ignore = "need ssh server"] async fn test_ssh_cred_helper_local() -> Result<()> { + let test_repo = TestRepo::new("main")?; + let repo = GitRepo::open(test_repo.path())?; let remote_url = "ssh://git@localhost:2222/datasets/test/td"; let parsed_url: GitUrl = remote_url.parse()?; - let ssh_helper = SSHCredentialHelper::new(&parsed_url, Operation::Download); + let ssh_helper = SSHCredentialHelper::new(&parsed_url, &repo, Operation::Download); let response = ssh_helper.authenticate().await?; @@ -107,9 +112,11 @@ mod tests { #[tokio::test] #[ignore = "need ssh key"] async fn test_ssh_cred_helper_remote() -> Result<()> { + let test_repo = TestRepo::new("main")?; + let repo = GitRepo::open(test_repo.path())?; let remote_url = "ssh://git@hf.co/seanses/tm"; // it seems that ssh port is not open on "huggingface.co" let parsed_url: GitUrl = remote_url.parse()?; - let ssh_helper = SSHCredentialHelper::new(&parsed_url, Operation::Upload); + let ssh_helper = SSHCredentialHelper::new(&parsed_url, &repo, Operation::Upload); let response = ssh_helper.authenticate().await?; diff --git a/git_xet/src/git_url.rs b/git_xet/src/git_url.rs index 13673e26..d5dd402f 100644 --- a/git_xet/src/git_url.rs +++ b/git_xet/src/git_url.rs @@ -124,7 +124,8 @@ impl GitUrl { } // Returns the front part of the Git remote URL removing repo path, - // e.g. (scheme://)(auth)[host_name](:port) + // e.g. (scheme://)(auth@)[host_name](:port) + #[allow(unused)] pub fn host_url(&self) -> Result { let scheme_str = if self.inner.scheme_prefix { format!("{}://", self.inner.scheme) @@ -159,6 +160,29 @@ impl GitUrl { Ok(format!("{scheme_str}{auth_str}{host}{port_str}")) } + // Returns the user and host in the format of (auth@)[host]. + pub fn user_and_host(&self) -> Result { + let auth_str = match self.inner.scheme { + Scheme::Http | Scheme::Https | Scheme::Ssh | Scheme::GitSsh => { + match (&self.inner.user, &self.inner.token) { + (Some(user), Some(token)) => format!("{}:{}@", user, token), + (Some(user), None) => format!("{}@", user), + (None, Some(token)) => format!("{}@", token), + (None, None) => "".to_owned(), + } + }, + _ => "".to_owned(), + }; + + let host = self + .inner + .host + .as_ref() + .ok_or_else(|| GitXetError::config_error("remote URL missing host name"))?; + + Ok(format!("{auth_str}{host}")) + } + pub fn port(&self) -> Option { self.inner.port } diff --git a/git_xet/src/utils/mod.rs b/git_xet/src/utils/mod.rs index 6acb3f2e..5426b76f 100644 --- a/git_xet/src/utils/mod.rs +++ b/git_xet/src/utils/mod.rs @@ -1 +1,2 @@ pub mod process_wrapping; +pub mod ssh_connect; diff --git a/git_xet/src/utils/process_wrapping.rs b/git_xet/src/utils/process_wrapping.rs index 43eff67d..3986622d 100644 --- a/git_xet/src/utils/process_wrapping.rs +++ b/git_xet/src/utils/process_wrapping.rs @@ -33,7 +33,7 @@ where /// Return `Ok(())` if the command finishes correctly and the child's stdout and stderr are ignored; /// Return the underlying I/O error if the child process spawning or waiting fails; otherwise, the captured /// stdout and stderr of the child are wrapped in an `Err(GitXetError::CommandFailed(_))` and returned. -#[allow(dead_code)] +#[allow(unused)] pub fn run_program_captured(program: S1, working_dir: P, args: I) -> Result<()> where S1: AsRef, @@ -98,7 +98,6 @@ where /// /// let (response, _err) = cmd.wait_with_output()?; /// ``` -#[allow(dead_code)] pub fn run_program_captured_with_input_and_output( program: S1, working_dir: P, diff --git a/git_xet/src/utils/ssh_connect.rs b/git_xet/src/utils/ssh_connect.rs new file mode 100644 index 00000000..0de7d036 --- /dev/null +++ b/git_xet/src/utils/ssh_connect.rs @@ -0,0 +1,522 @@ +// A utility to help establish SSH connection to a remote Git server. + +use crate::errors::{GitXetError, Result}; +use crate::git_repo::GitRepo; + +// The type of SSH program to use for SSH connections, valid values are given +// at https://git-scm.com/docs/git-config#Documentation/git-config.txt-sshvariant. +enum Variant { + Auto, + Ssh, + Simple, + Putty, + Tortoise, +} + +impl From<&str> for Variant { + fn from(value: &str) -> Variant { + match value.to_ascii_lowercase().as_str() { + "" | "auto" => Variant::Auto, + "simple" => Variant::Simple, + "putty" | "plink" => Variant::Putty, + "tortoiseplink" => Variant::Tortoise, + _ => Variant::Ssh, + } + } +} + +pub struct SSHMetadata { + pub user_and_host: String, + pub port: Option, + pub arg_list: Vec, +} + +const DEFAULT_SSH_CMD: &str = "ssh"; +const DEFAULT_GIT_BASH: &str = "sh"; + +// Return the executable name to execute an SSH command on this machine and the base args. +// Format the args as needed if the command needs to run in a shell. +// +// Git allows using a list of environment variables and git configs to control how to establish +// an SSH connection to the remote server. +// +// 1. Env vars $GIT_SSH_COMMAND and $GIT_SSH and git config entry "core.sshCommand" define +// which ssh executable to use for SSH connection. $GIT_SSH_COMMAND takes precedence over "core.sshCommand" +// and both are interpreted by the shell, which allows additional arguments to be included. They both takes +// precedence over $GIT_SSH, which on the other hand must be just the path to a program (which can be a wrapper +// shell script, if additional arguments are needed). +// +// 2. Env var $GIT_SSH_VARIANT takes precedence over git config entry "ssh.variant" and they both define whether +// $GIT_SSH/$GIT_SSH_COMMAND/core.sshCommand refer to OpenSSH, plink/putty or tortoiseplink, or instruct git to +// automatically detect the ssh program type. Valid values are "ssh" (to use OpenSSH options), "plink", "putty", +// "tortoiseplink", "simple" (no options except the host and remote command). The default auto-detection can be +// explicitly requested using the value "auto". Any other value is treated as "ssh". +// +// This implementation follows how the same functionality is handled in +// git-lfs (https://github.com/git-lfs/git-lfs/blob/071e19e8ea03b1e40b181706909fb8c18d928e29/ssh/ssh.go#L41). +pub fn get_sshcmd_and_args(meta: &SSHMetadata, repo: &GitRepo) -> Result<(String, Vec)> { + let (cmd, args, need_shell) = get_sshexe_and_args(meta, repo)?; + + if !need_shell { + Ok((cmd, args)) + } else { + format_for_shell_execution(&cmd, &args) + } +} + +// Return the executable name for ssh on this machine and the base args. +// Base args includes port settings, user/host, everything pre the command to execute. +// +// This implementation follows how the same functionality is handled in +// git (https://github.com/git/git/blob/dc70283dfcdc420d330547fc1d3cba0d29bfd2d0/connect.c#L1367) and +// git-lfs (https://github.com/git-lfs/git-lfs/blob/071e19e8ea03b1e40b181706909fb8c18d928e29/ssh/ssh.go#L127). +pub fn get_sshexe_and_args(meta: &SSHMetadata, repo: &GitRepo) -> Result<(String, Vec, bool)> { + let repo_config = repo.config()?; + + let sshexe = std::env::var("GIT_SSH").unwrap_or_default(); + let ssh_cmd = std::env::var("GIT_SSH_COMMAND").unwrap_or_default(); + + let (mut sshexe, mut cmd, mut need_shell) = parse_shell_command(&ssh_cmd, &sshexe); + if sshexe.is_empty() { + let ssh_cmd = repo_config.get_string("core.sshcommand").unwrap_or_default(); + (sshexe, cmd, need_shell) = parse_shell_command(&ssh_cmd, DEFAULT_SSH_CMD); + } + + let variant = get_ssh_variant(&repo_config, &sshexe); + + if matches!(variant, Variant::Simple) { + return Err(GitXetError::not_supported( + "unable to construct an ssh command using an ssh program of \"simple\" variant. Please + use an advanced ssh program and update environment variables \"GIT_SSH_COMMAND\", \"GIT_SSH\", + \"GIT_SSH_VARIANT\" and git config entries \"core.sshCommand\" and \"ssh.variant\" accordingly. + For details, see https://git-scm.com/docs/git-config#Documentation/git-config.txt-sshvariant. + ", + )); + } + + if cmd.is_empty() { + cmd = sshexe; + } + + let mut args = Vec::::new(); + + if matches!(variant, Variant::Tortoise) { + // TortoisePlink requires the -batch argument to behave like ssh/plink + args.push("-batch".into()); + } + + if let Some(p) = meta.port { + if matches!(variant, Variant::Putty | Variant::Tortoise) { + args.push("-P".into()); + } else { + args.push("-p".into()); + } + args.push(p.to_string()); + } + + args.push(meta.user_and_host.clone()); + args.extend_from_slice(&meta.arg_list); + + Ok((cmd, args, need_shell)) +} + +// Parse command, and if it looks like a valid command, return the ssh executable +// name, the command to run, and whether we need a shell. If not, return +// existing as the ssh binary name. +fn parse_shell_command(command: &str, existing: &str) -> (String, String, bool) { + let parsed_command = shell_words::split(command); + // Is it a valid command? + if let Ok(mut p) = parsed_command + && !p.is_empty() + { + // We don't need the rest of the parsed result, so do a quick removal + // that doesn't preserve the elements order. + (p.swap_remove(0), command.into(), true) + } else { + (existing.into(), "".into(), false) + } +} + +// Find out which type of SSH program is used, this allows constructing the call args accordingly. +// See https://git-scm.com/docs/git-config#Documentation/git-config.txt-sshvariant for details. +fn get_ssh_variant(repo_config: &git2::Config, sshexe: &str) -> Variant { + let variant_str = std::env::var("GIT_SSH_VARIANT") + .or_else(|_| repo_config.get_string("ssh.variant")) + .unwrap_or_default(); + let variant = Variant::from(variant_str.as_str()); + + if matches!(variant, Variant::Auto) + && let Some(base_exe_name) = std::path::Path::new(sshexe).file_stem() + { + match base_exe_name.to_ascii_lowercase().to_str() { + Some("plink") => return Variant::Putty, + Some("tortoiseplink") => return Variant::Tortoise, + _ => (), + } + + Variant::Ssh + } else { + variant + } +} + +// Format a shell command and a subsequent list of args to a syntax correct command +// to be executed by a shell. +// Return the executable name and the args to execute this shell command. +fn format_for_shell_execution(command: &str, args: &[String]) -> Result<(String, Vec)> { + let parsed_command = shell_words::split(command) + .map_err(|e| GitXetError::config_error(format!("parsing ssh command failed with {e}")))?; + let complete_shell_command = shell_words::join(parsed_command.iter().chain(args.iter())); + + Ok((DEFAULT_GIT_BASH.into(), vec!["-c".into(), complete_shell_command])) +} + +#[cfg(test)] +mod tests { + use anyhow::{Ok, Result}; + use serial_test::serial; + use utils::EnvVarGuard; + + use super::*; + use crate::git_repo::GitRepo; + use crate::test_utils::TestRepo; + + #[test] + #[serial(env_var_write_read)] + fn test_get_ssh_variant_explicit() -> Result<()> { + let test_repo = TestRepo::new("main")?; + let repo = GitRepo::open(test_repo.path())?; + + { + let _env = EnvVarGuard::set("GIT_SSH_VARIANT", "putty"); + let repo_config = repo.config()?; + let variant = get_ssh_variant(&repo_config, "/usr/bin/plink"); + assert!(matches!(variant, Variant::Putty)); + } + + { + let _env = EnvVarGuard::set("GIT_SSH_VARIANT", "putty"); + let repo_config = repo.config()?; + let variant = get_ssh_variant(&repo_config, r#"C:\Program Files\Putty\plink.exe"#); + assert!(matches!(variant, Variant::Putty)); + } + + { + test_repo.set_config("ssh.variant", "ssh")?; + let repo_config = repo.config()?; + let variant = get_ssh_variant(&repo_config, "openssh"); + assert!(matches!(variant, Variant::Ssh)); + } + + { + test_repo.set_config("ssh.variant", "openssh")?; + let repo_config = repo.config()?; + let variant = get_ssh_variant(&repo_config, "Openssh.exe"); + assert!(matches!(variant, Variant::Ssh)); + } + + Ok(()) + } + + #[test] + #[serial(env_var_write_read)] + fn test_get_ssh_variant_explicit_override() -> Result<()> { + let test_repo = TestRepo::new("main")?; + let repo = GitRepo::open(test_repo.path())?; + + { + // $GIT_SSH_VARIANT (putty) overrides "ssh.variant" (ssh). + let _env = EnvVarGuard::set("GIT_SSH_VARIANT", "putty"); + test_repo.set_config("ssh.variant", "ssh")?; + let repo_config = repo.config()?; + let variant = get_ssh_variant(&repo_config, "/usr/bin/plink"); + assert!(matches!(variant, Variant::Putty)); + } + + { + // $GIT_SSH_VARIANT (simple) overrides "ssh.variant" (auto). + let _env = EnvVarGuard::set("GIT_SSH_VARIANT", "simple"); + test_repo.set_config("ssh.variant", "auto")?; + let repo_config = repo.config()?; + let variant = get_ssh_variant(&repo_config, r#"C:\sshexe"#); + assert!(matches!(variant, Variant::Simple)); + } + + Ok(()) + } + + #[test] + fn test_parse_shell_command() { + // Test case 1: Empty command string + let (sshexe, cmd, need_shell) = parse_shell_command("", "default_ssh"); + assert_eq!(sshexe, "default_ssh"); + assert_eq!(cmd, ""); + assert!(!need_shell); + + // Test case 2: Simple command + let (sshexe, cmd, need_shell) = parse_shell_command("ssh", "default_ssh"); + assert_eq!(sshexe, "ssh"); + assert_eq!(cmd, "ssh"); + assert!(need_shell); + + // Test case 3: Command with arguments + let (sshexe, cmd, need_shell) = parse_shell_command("ssh -i ~/.ssh/id_rsa", "default_ssh"); + assert_eq!(sshexe, "ssh"); + assert_eq!(cmd, "ssh -i ~/.ssh/id_rsa"); + assert!(need_shell); + + // Test case 4: Command with quoted arguments + let (sshexe, cmd, need_shell) = parse_shell_command("ssh -o \"StrictHostKeyChecking no\"", "default_ssh"); + assert_eq!(sshexe, "ssh"); + assert_eq!(cmd, "ssh -o \"StrictHostKeyChecking no\""); + assert!(need_shell); + + // Test case 5: Command with multiple spaces + let (sshexe, cmd, need_shell) = parse_shell_command(" ssh -v ", "default_ssh"); + assert_eq!(sshexe, "ssh"); + assert_eq!(cmd, " ssh -v "); + assert!(need_shell); + + // Test case 6: Command that is just whitespace + let (sshexe, cmd, need_shell) = parse_shell_command(" ", "default_ssh"); + assert_eq!(sshexe, "default_ssh"); + assert_eq!(cmd, ""); + assert!(!need_shell); + + // Test case 7: Command with a different executable name + let (sshexe, cmd, need_shell) = parse_shell_command("/usr/bin/custom_ssh", "ssh"); + assert_eq!(sshexe, "/usr/bin/custom_ssh"); + assert_eq!(cmd, "/usr/bin/custom_ssh"); + assert!(need_shell); + } + + #[test] + #[serial(env_var_write_read)] + fn test_get_sshexe_and_args_no_port() -> Result<()> { + let test_repo = TestRepo::new("main")?; + let repo = GitRepo::open(test_repo.path())?; + + let meta = SSHMetadata { + user_and_host: "git@hf.co".into(), + port: None, + arg_list: vec!["auth".into(), "org/repo".into(), "upload".into()], + }; + + // Test with default SSH variant (OpenSSH) + let (cmd, args, need_shell) = get_sshexe_and_args(&meta, &repo)?; + assert_eq!(cmd, DEFAULT_SSH_CMD); + assert_eq!(args, vec!["git@hf.co", "auth", "org/repo", "upload"]); + assert!(!need_shell); + + // Test with GIT_SSH_COMMAND + { + let _env = EnvVarGuard::set("GIT_SSH_COMMAND", "ssh -i ~/.ssh/id_rsa"); + let (cmd, args, need_shell) = get_sshexe_and_args(&meta, &repo)?; + assert_eq!(cmd, "ssh -i ~/.ssh/id_rsa"); + assert_eq!(args, vec!["git@hf.co", "auth", "org/repo", "upload"]); + assert!(need_shell); + } + + // Test with GIT_SSH + { + let _env = EnvVarGuard::set("GIT_SSH", "/usr/bin/custom_ssh"); + let (cmd, args, need_shell) = get_sshexe_and_args(&meta, &repo)?; + assert_eq!(cmd, "/usr/bin/custom_ssh"); + assert_eq!(args, vec!["git@hf.co", "auth", "org/repo", "upload"]); + assert!(!need_shell); + } + + // Test with ssh variant + { + let _env_ssh = EnvVarGuard::set("GIT_SSH", r#"C:\Program Files\Tortoiseplink.exe"#); + let _env_variant = EnvVarGuard::set("GIT_SSH_VARIANT", "tortoiseplink"); + let (cmd, args, need_shell) = get_sshexe_and_args(&meta, &repo)?; + assert_eq!(cmd, r#"C:\Program Files\Tortoiseplink.exe"#); + assert_eq!(args, vec!["-batch", "git@hf.co", "auth", "org/repo", "upload"]); + assert!(!need_shell); + } + + // Test with core.sshCommand + { + test_repo.set_config("core.sshCommand", "ssh -v")?; + let repo = GitRepo::open(test_repo.path())?; + let (cmd, args, need_shell) = get_sshexe_and_args(&meta, &repo)?; + assert_eq!(cmd, "ssh -v"); + assert_eq!(args, vec!["git@hf.co", "auth", "org/repo", "upload"]); + assert!(need_shell); + } + + // Test with core.sshCommand with quotes + { + test_repo.set_config("core.sshCommand", "ssh -o \"StrictHostKeyChecking no\"")?; + let repo = GitRepo::open(test_repo.path())?; + let (cmd, args, need_shell) = get_sshexe_and_args(&meta, &repo)?; + assert_eq!(cmd, "ssh -o \"StrictHostKeyChecking no\""); + assert_eq!(args, vec!["git@hf.co", "auth", "org/repo", "upload"]); + assert!(need_shell); + } + + Ok(()) + } + + #[test] + #[serial(env_var_write_read)] + fn test_get_sshexe_and_args_with_port() -> Result<()> { + let test_repo = TestRepo::new("main")?; + let repo = GitRepo::open(test_repo.path())?; + + let meta = SSHMetadata { + user_and_host: "git@hf.co".to_string(), + port: Some(2222), + arg_list: vec!["auth".into(), "org/repo".into(), "upload".into()], + }; + + // Test with default SSH variant (OpenSSH) + let (cmd, args, need_shell) = get_sshexe_and_args(&meta, &repo)?; + assert_eq!(cmd, DEFAULT_SSH_CMD); + assert_eq!(args, vec!["-p", "2222", "git@hf.co", "auth", "org/repo", "upload"]); + assert!(!need_shell); + + // Test with Putty variant + { + let _env = EnvVarGuard::set("GIT_SSH_VARIANT", "putty"); + let (cmd, args, need_shell) = get_sshexe_and_args(&meta, &repo)?; + assert_eq!(cmd, DEFAULT_SSH_CMD); + assert_eq!(args, vec!["-P", "2222", "git@hf.co", "auth", "org/repo", "upload"]); + assert!(!need_shell); + } + + // Test with Tortoise variant + { + let _env = EnvVarGuard::set("GIT_SSH_VARIANT", "tortoiseplink"); + let (cmd, args, need_shell) = get_sshexe_and_args(&meta, &repo)?; + assert_eq!(cmd, DEFAULT_SSH_CMD); + assert_eq!(args, vec!["-batch", "-P", "2222", "git@hf.co", "auth", "org/repo", "upload"]); + assert!(!need_shell); + } + + Ok(()) + } + + #[test] + #[serial(env_var_write_read)] + fn test_get_sshexe_and_args_simple_variant_error() -> Result<()> { + let test_repo = TestRepo::new("main")?; + let repo = GitRepo::open(test_repo.path())?; + + let meta = SSHMetadata { + user_and_host: "git@github.com".to_string(), + port: None, + arg_list: vec!["auth".into(), "org/repo".into(), "upload".into()], + }; + + // Test with 'simple' variant, which should error + { + let _env = EnvVarGuard::set("GIT_SSH_VARIANT", "simple"); + let result = get_sshexe_and_args(&meta, &repo); + assert!(result.is_err()); + assert!(matches!(result.unwrap_err(), GitXetError::NotSupported(_))); + } + + Ok(()) + } + + #[test] + #[serial(env_var_write_read)] + fn test_get_sshcmd_and_args_no_port() -> Result<()> { + let test_repo = TestRepo::new("main")?; + let repo = GitRepo::open(test_repo.path())?; + + let meta = SSHMetadata { + user_and_host: "git@hf.co".into(), + port: None, + arg_list: vec!["auth".into(), "org/repo".into(), "upload".into()], + }; + + // Test with default SSH variant (OpenSSH) + let (cmd, args) = get_sshcmd_and_args(&meta, &repo)?; + assert_eq!(cmd, DEFAULT_SSH_CMD); + assert_eq!(args, vec!["git@hf.co", "auth", "org/repo", "upload"]); + + // Test with GIT_SSH_COMMAND + { + let _env = EnvVarGuard::set("GIT_SSH_COMMAND", "ssh -i ~/.ssh/id_rsa"); + let (cmd, args) = get_sshcmd_and_args(&meta, &repo)?; + assert_eq!(cmd, "sh"); + assert_eq!(args, vec!["-c", "ssh -i ~/.ssh/id_rsa git@hf.co auth org/repo upload"]); + } + + // Test with GIT_SSH + { + let _env = EnvVarGuard::set("GIT_SSH", "/usr/bin/custom_ssh"); + let (cmd, args) = get_sshcmd_and_args(&meta, &repo)?; + assert_eq!(cmd, "/usr/bin/custom_ssh"); + assert_eq!(args, vec!["git@hf.co", "auth", "org/repo", "upload"]); + } + + // Test with ssh variant + { + let _env_ssh = EnvVarGuard::set("GIT_SSH", r#"C:\Program Files\Tortoiseplink.exe"#); + let _env_variant = EnvVarGuard::set("GIT_SSH_VARIANT", "tortoiseplink"); + let (cmd, args) = get_sshcmd_and_args(&meta, &repo)?; + assert_eq!(cmd, r#"C:\Program Files\Tortoiseplink.exe"#); + assert_eq!(args, vec!["-batch", "git@hf.co", "auth", "org/repo", "upload"]); + } + + // Test with core.sshCommand + { + test_repo.set_config("core.sshCommand", "ssh -v")?; + let repo = GitRepo::open(test_repo.path())?; + let (cmd, args) = get_sshcmd_and_args(&meta, &repo)?; + assert_eq!(cmd, "sh"); + assert_eq!(args, vec!["-c", "ssh -v git@hf.co auth org/repo upload"]); + } + + // Test with core.sshCommand with quotes + { + test_repo.set_config("core.sshCommand", "ssh -o \"StrictHostKeyChecking no\"")?; + let repo = GitRepo::open(test_repo.path())?; + let (cmd, args) = get_sshcmd_and_args(&meta, &repo)?; + assert_eq!(cmd, "sh"); + assert_eq!(args, vec!["-c", "ssh -o 'StrictHostKeyChecking no' git@hf.co auth org/repo upload"]); + } + + Ok(()) + } + + #[test] + #[serial(env_var_write_read)] + fn test_get_sshcmd_and_args_with_port() -> Result<()> { + let test_repo = TestRepo::new("main")?; + let repo = GitRepo::open(test_repo.path())?; + + let meta = SSHMetadata { + user_and_host: "git@hf.co".to_string(), + port: Some(2222), + arg_list: vec!["auth".into(), "org/repo".into(), "upload".into()], + }; + + // Test with default SSH variant (OpenSSH) + let (cmd, args) = get_sshcmd_and_args(&meta, &repo)?; + assert_eq!(cmd, DEFAULT_SSH_CMD); + assert_eq!(args, vec!["-p", "2222", "git@hf.co", "auth", "org/repo", "upload"]); + + // Test with Putty variant + { + let _env = EnvVarGuard::set("GIT_SSH_VARIANT", "putty"); + let (cmd, args) = get_sshcmd_and_args(&meta, &repo)?; + assert_eq!(cmd, DEFAULT_SSH_CMD); + assert_eq!(args, vec!["-P", "2222", "git@hf.co", "auth", "org/repo", "upload"]); + } + + // Test with Tortoise variant + { + let _env = EnvVarGuard::set("GIT_SSH_VARIANT", "tortoiseplink"); + let (cmd, args) = get_sshcmd_and_args(&meta, &repo)?; + assert_eq!(cmd, DEFAULT_SSH_CMD); + assert_eq!(args, vec!["-batch", "-P", "2222", "git@hf.co", "auth", "org/repo", "upload"]); + } + + Ok(()) + } +}