diff --git a/fixtures/.rustscan_scripts.toml b/fixtures/.rustscan_scripts.toml new file mode 100644 index 000000000..3f5c6e026 --- /dev/null +++ b/fixtures/.rustscan_scripts.toml @@ -0,0 +1,4 @@ +tags = ["core_approved", "example"] +ports = ["80", "443", "8080"] +developer = ["example", "https://example.org"] +directory = "fixtures/.rustscan_scripts" \ No newline at end of file diff --git a/src/scripts/mod.rs b/src/scripts/mod.rs index 410fd21a1..e21b427ee 100644 --- a/src/scripts/mod.rs +++ b/src/scripts/mod.rs @@ -108,17 +108,21 @@ pub fn init_scripts(scripts: &ScriptsRequired) -> Result> { scripts_to_run.push(default_script); } ScriptsRequired::Custom => { - let scripts_dir_base = - dirs::home_dir().ok_or_else(|| anyhow!("Could not infer scripts path."))?; - let script_paths = find_scripts(scripts_dir_base)?; + let script_config = ScriptConfig::read_config()?; + debug!("Script config \n{:?}", script_config); + + let script_dir_base = if let Some(config_directory) = &script_config.directory { + PathBuf::from(config_directory) + } else { + dirs::home_dir().ok_or_else(|| anyhow!("Could not infer scripts path."))? + }; + + let script_paths = find_scripts(script_dir_base)?; debug!("Scripts paths \n{:?}", script_paths); let parsed_scripts = parse_scripts(script_paths); debug!("Scripts parsed \n{:?}", parsed_scripts); - let script_config = ScriptConfig::read_config()?; - debug!("Script config \n{:?}", script_config); - // Only Scripts that contain all the tags found in ScriptConfig will be selected. if let Some(config_hashset) = script_config.tags { for script in parsed_scripts { @@ -316,8 +320,7 @@ fn execute_script(script: &str) -> Result { } } -pub fn find_scripts(mut path: PathBuf) -> Result> { - path.push(".rustscan_scripts"); +pub fn find_scripts(path: PathBuf) -> Result> { if path.is_dir() { debug!("Scripts folder found {}", &path.display()); let mut files_vec: Vec = Vec::new(); @@ -382,6 +385,7 @@ pub struct ScriptConfig { pub tags: Option>, pub ports: Option>, pub developer: Option>, + pub directory: Option, } #[cfg(not(tarpaulin_include))] @@ -400,7 +404,7 @@ impl ScriptConfig { #[cfg(test)] mod tests { - use super::{find_scripts, parse_scripts, Script, ScriptFile}; + use super::*; // Function for testing only, it inserts static values into ip and open_ports // Doesn't use impl in case it's implemented in the super module at some point @@ -418,7 +422,7 @@ mod tests { #[test] fn find_and_parse_scripts() { - let scripts = find_scripts("fixtures/".into()).unwrap(); + let scripts = find_scripts("fixtures/.rustscan_scripts".into()).unwrap(); let scripts = parse_scripts(scripts); assert_eq!(scripts.len(), 4); } @@ -515,4 +519,49 @@ mod tests { // output has a newline at the end by default, .trim() trims it assert_eq!(output.trim(), "Total args passed to fixtures/.rustscan_scripts/test_script.pl : 2\nArg # 1 : 127.0.0.1\nArg # 2 : 80,8080"); } + + #[test] + fn test_custom_directory_config() { + // Create test config + let config_str = r#" + tags = ["core_approved", "example"] + directory = "fixtures/.rustscan_scripts" + "#; + + let config: ScriptConfig = toml::from_str(config_str).unwrap(); + assert_eq!( + config.directory, + Some("fixtures/.rustscan_scripts".to_string()) + ); + + // Test that the directory is actually used + let script_dir_base = PathBuf::from(config.directory.unwrap()); + let scripts = find_scripts(script_dir_base).unwrap(); + + // Verify we found the test script + assert!(scripts.iter().any(|p| p + .file_name() + .and_then(|f| f.to_str()) + .map(|s| s == "test_script.txt") + .unwrap_or(false))); + } + + #[test] + fn test_default_directory_fallback() { + let config_str = r#" + tags = ["core_approved", "example"] + "#; + + let config: ScriptConfig = toml::from_str(config_str).unwrap(); + assert_eq!(config.directory, None); + + // Test fallback to home directory + let script_dir_base = if let Some(config_directory) = &config.directory { + PathBuf::from(config_directory) + } else { + dirs::home_dir().unwrap() + }; + + assert_eq!(script_dir_base, dirs::home_dir().unwrap()); + } }