From b3f418680b5e04d7c6be2daa57f31946b9b97c66 Mon Sep 17 00:00:00 2001
From: Austin Bonander
Date: Tue, 7 Oct 2025 17:55:59 -0700
Subject: [PATCH] fix(macros): smarter `.env` loading, caching, and invalidation

---
 Cargo.lock                             |  42 +++--
 Cargo.toml                             |   1 +
 sqlx-core/Cargo.toml                   |  11 +-
 sqlx-core/src/config/mod.rs            |  11 +-
 sqlx-macros-core/Cargo.toml            |   3 +-
 sqlx-macros-core/clippy.toml           |   3 +
 sqlx-macros-core/src/lib.rs            |  23 ++-
 sqlx-macros-core/src/query/cache.rs    |  97 ++++++++++
 sqlx-macros-core/src/query/data.rs     | 125 +++++++------
 sqlx-macros-core/src/query/metadata.rs | 162 +++++++++++++++++
 sqlx-macros-core/src/query/mod.rs      | 238 ++++---------------------
 sqlx-mysql/Cargo.toml                  |   5 +-
 sqlx-postgres/Cargo.toml               |   5 +-
 sqlx-sqlite/Cargo.toml                 |   3 +-
 14 files changed, 450 insertions(+), 279 deletions(-)
 create mode 100644 sqlx-macros-core/clippy.toml
 create mode 100644 sqlx-macros-core/src/query/cache.rs
 create mode 100644 sqlx-macros-core/src/query/metadata.rs

diff --git a/Cargo.lock b/Cargo.lock
index 61d2e7d7b6..d611267ccc 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -1405,6 +1405,12 @@ version = "0.1.5"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "d9c4f5dac5e15c24eb999c26181a6ca40b39fe946cbe4c263c7209467bc83af2"
 
+[[package]]
+name = "foldhash"
+version = "0.2.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "77ce24cb58228fbb8aa041425bb1050850ac19177686ea6e0f41a70416f56fdb"
+
 [[package]]
 name = "foreign-types"
 version = "0.3.2"
@@ -1642,7 +1648,18 @@ checksum = "5971ac85611da7067dbfcabef3c70ebb5606018acd9e2a3903a0da507521e0d5"
 dependencies = [
  "allocator-api2",
  "equivalent",
- "foldhash",
+ "foldhash 0.1.5",
+]
+
+[[package]]
+name = "hashbrown"
+version = "0.16.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "5419bdc4f6a9207fbeba6d11b604d481addf78ecd10c11ad51e76c2f6482748d"
+dependencies = [
+ "allocator-api2",
+ "equivalent",
+ "foldhash 0.2.0",
 ]
 
 [[package]]
@@ -3559,7 +3576,7 @@ dependencies = [
  "futures-intrusive",
  "futures-io",
  "futures-util",
- "hashbrown 0.15.4",
+ "hashbrown 0.16.0",
  "hashlink",
  "indexmap 2.10.0",
  "ipnet",
@@ -3578,7 +3595,7 @@ dependencies = [
  "smallvec",
  "smol",
  "sqlx",
- "thiserror 2.0.12",
+ "thiserror 2.0.17",
  "time",
  "tokio",
  "tokio-stream",
@@ -3613,7 +3630,7 @@ dependencies = [
  "serde_json",
  "serde_with",
  "sqlx",
- "thiserror 2.0.12",
+ "thiserror 2.0.17",
  "time",
  "tokio",
  "tower",
@@ -3866,6 +3883,7 @@ dependencies = [
  "sqlx-postgres",
  "sqlx-sqlite",
  "syn 2.0.104",
+ "thiserror 2.0.17",
  "tokio",
  "url",
 ]
@@ -3908,7 +3926,7 @@ dependencies = [
  "sqlx",
  "sqlx-core",
  "stringprep",
- "thiserror 2.0.12",
+ "thiserror 2.0.17",
  "time",
  "tracing",
  "uuid",
@@ -3953,7 +3971,7 @@ dependencies = [
  "sqlx",
  "sqlx-core",
  "stringprep",
- "thiserror 2.0.12",
+ "thiserror 2.0.17",
  "time",
  "tracing",
  "uuid",
@@ -3980,7 +3998,7 @@ dependencies = [
  "serde_urlencoded",
  "sqlx",
  "sqlx-core",
- "thiserror 2.0.12",
+ "thiserror 2.0.17",
  "time",
  "tracing",
  "url",
@@ -4164,11 +4182,11 @@
 
 [[package]]
 name = "thiserror"
-version = "2.0.12"
+version = "2.0.17"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "567b8a2dae586314f7be2a752ec7474332959c6460e02bde30d702a66d488708"
+checksum = "f63587ca0f12b72a0600bcba1d40081f830876000bb46dd2337a3051618f4fc8"
 dependencies = [
- "thiserror-impl 2.0.12",
+ "thiserror-impl 2.0.17",
 ]
 
@@ -4184,9 +4202,9 @@ dependencies = [
 
 [[package]]
 name = "thiserror-impl"
-version = "2.0.12"
+version = "2.0.17"
 source = "registry+https://github.com/rust-lang/crates.io-index"
= "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7f7cf42b4507d8ea322120659672cf1b9dbb93f8f2d4ecfd6e51350ff5b17a1d" +checksum = "3ff15c8ecd7de3849db632e14d18d2571fa09dfc5ed93479bc4485c7a517c913" dependencies = [ "proc-macro2", "quote", diff --git a/Cargo.toml b/Cargo.toml index b24b59cfa0..00d5d656c1 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -190,6 +190,7 @@ uuid = "1.1.2" # Common utility crates cfg-if = "1.0.0" dotenvy = { version = "0.15.0", default-features = false } +thiserror = { version = "2.0.17", default-features = false, features = ["std"] } # Runtimes [workspace.dependencies.async-global-executor] diff --git a/sqlx-core/Cargo.toml b/sqlx-core/Cargo.toml index 58c5b67e05..e22c6d4fc3 100644 --- a/sqlx-core/Cargo.toml +++ b/sqlx-core/Cargo.toml @@ -62,6 +62,7 @@ rustls-native-certs = { version = "0.8.0", optional = true } # Type Integrations bit-vec = { workspace = true, optional = true } bigdecimal = { workspace = true, optional = true } +chrono = { workspace = true, optional = true } rust_decimal = { workspace = true, optional = true } time = { workspace = true, optional = true } ipnet = { workspace = true, optional = true } @@ -69,15 +70,14 @@ ipnetwork = { workspace = true, optional = true } mac_address = { workspace = true, optional = true } uuid = { workspace = true, optional = true } +# work around bug in async-fs 2.0.0, which references futures-lite dependency wrongly, see https://github.com/launchbadge/sqlx/pull/3791#issuecomment-3043363281 +async-fs = { version = "2.1", optional = true } async-io = { version = "2.4.1", optional = true } async-task = { version = "4.7.1", optional = true } -# work around bug in async-fs 2.0.0, which references futures-lite dependency wrongly, see https://github.com/launchbadge/sqlx/pull/3791#issuecomment-3043363281 -async-fs = { version = "2.1", optional = true } base64 = { version = "0.22.0", default-features = false, features = ["std"] } bytes = "1.1.0" cfg-if = { workspace = true } -chrono = { version = "0.4.34", default-features = false, features = ["clock"], optional = true } crc = { version = "3", optional = true } crossbeam-queue = "0.3.2" either = "1.6.1" @@ -93,7 +93,6 @@ serde_json = { version = "1.0.73", features = ["raw_value"], optional = true } toml = { version = "0.8.16", optional = true } sha2 = { version = "0.10.0", default-features = false, optional = true } #sqlformat = "0.2.0" -thiserror = "2.0.0" tokio-stream = { version = "0.1.8", features = ["fs"], optional = true } tracing = { version = "0.1.37", features = ["log"] } smallvec = "1.7.0" @@ -102,7 +101,9 @@ bstr = { version = "1.0", default-features = false, features = ["std"], optional hashlink = "0.10.0" indexmap = "2.0" event-listener = "5.2.0" -hashbrown = "0.15.0" +hashbrown = "0.16.0" + +thiserror.workspace = true [dev-dependencies] sqlx = { workspace = true, features = ["postgres", "sqlite", "mysql", "migrate", "macros", "time", "uuid"] } diff --git a/sqlx-core/src/config/mod.rs b/sqlx-core/src/config/mod.rs index 267a2f1ed1..960d0f434f 100644 --- a/sqlx-core/src/config/mod.rs +++ b/sqlx-core/src/config/mod.rs @@ -158,7 +158,16 @@ impl Config { /// * If the file exists but could not be read or parsed. /// * If the file exists but the `sqlx-toml` feature is disabled. pub fn try_from_crate_or_default() -> Result { - Self::read_from(get_crate_path()?).or_else(|e| { + Self::try_from_path_or_default(get_crate_path()?) + } + + /// Attempt to read `Config` from the path given, or return `Config::default()` if it does not exist. 
diff --git a/sqlx-macros-core/Cargo.toml b/sqlx-macros-core/Cargo.toml
index 3bcbede6f4..8702555086 100644
--- a/sqlx-macros-core/Cargo.toml
+++ b/sqlx-macros-core/Cargo.toml
@@ -26,7 +26,7 @@ _sqlite = []
 
 # SQLx features
 derive = []
-macros = []
+macros = ["thiserror"]
 migrate = ["sqlx-core/migrate"]
 sqlx-toml = ["sqlx-core/sqlx-toml", "sqlx-sqlite?/sqlx-toml"]
 
@@ -66,6 +66,7 @@ tokio = { workspace = true, optional = true }
 
 cfg-if = { workspace = true}
 dotenvy = { workspace = true }
+thiserror = { workspace = true, optional = true }
 
 hex = { version = "0.4.3" }
 heck = { version = "0.5" }
diff --git a/sqlx-macros-core/clippy.toml b/sqlx-macros-core/clippy.toml
new file mode 100644
index 0000000000..f303803661
--- /dev/null
+++ b/sqlx-macros-core/clippy.toml
@@ -0,0 +1,3 @@
+[[disallowed-methods]]
+path = "std::env::var"
+reason = "use `crate::env()` instead, which optionally calls `proc_macro::tracked_env::var()`"
diff --git a/sqlx-macros-core/src/lib.rs b/sqlx-macros-core/src/lib.rs
index 9d4204f814..c9eabfbbff 100644
--- a/sqlx-macros-core/src/lib.rs
+++ b/sqlx-macros-core/src/lib.rs
@@ -26,7 +26,7 @@ use crate::query::QueryDriver;
 
 pub type Error = Box<dyn std::error::Error + Send + Sync>;
 
-pub type Result<T> = std::result::Result<T, Error>;
+pub type Result<T, E = Error> = std::result::Result<T, E>;
 
 mod common;
 pub mod database;
@@ -84,3 +84,24 @@ where
         }
     }
 }
+
+pub fn env(var: &str) -> Result<String> {
+    env_opt(var)?
+        .ok_or_else(|| format!("env var {var:?} must be set to use the query macros").into())
+}
+
+pub fn env_opt(var: &str) -> Result<Option<String>> {
+    use std::env::VarError;
+
+    #[cfg(any(sqlx_macros_unstable, procmacro2_semver_exempt))]
+    let res: Result<String, VarError> = proc_macro::tracked_env::var(var);
+
+    #[cfg(not(any(sqlx_macros_unstable, procmacro2_semver_exempt)))]
+    let res: Result<String, VarError> = std::env::var(var);
+
+    match res {
+        Ok(val) => Ok(Some(val)),
+        Err(VarError::NotPresent) => Ok(None),
+        Err(VarError::NotUnicode(_)) => Err(format!("env var {var:?} is not valid UTF-8").into()),
+    }
+}
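The `env`/`env_opt` split distinguishes three states: set, unset, and set-but-not-UTF-8. A standalone sketch of the same decision logic; the nightly-only `tracked_env` branch is omitted, and the type alias mirrors the one in `lib.rs` above:

```rust
use std::env::VarError;

type Error = Box<dyn std::error::Error + Send + Sync>;
type Result<T, E = Error> = std::result::Result<T, E>;

fn env_opt(var: &str) -> Result<Option<String>> {
    match std::env::var(var) {
        Ok(val) => Ok(Some(val)),
        // Unset is not an error; the caller picks a fallback (e.g. a .env file).
        Err(VarError::NotPresent) => Ok(None),
        // Set-but-garbled is a hard error rather than being silently ignored.
        Err(VarError::NotUnicode(_)) => Err(format!("env var {var:?} is not valid UTF-8").into()),
    }
}

fn main() -> Result<()> {
    match env_opt("SQLX_OFFLINE")? {
        Some(val) => println!("SQLX_OFFLINE explicitly set to {val:?}"),
        None => println!("SQLX_OFFLINE unset; .env files and defaults apply"),
    }
    Ok(())
}
```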
diff --git a/sqlx-macros-core/src/query/cache.rs b/sqlx-macros-core/src/query/cache.rs
new file mode 100644
index 0000000000..f44366bac2
--- /dev/null
+++ b/sqlx-macros-core/src/query/cache.rs
@@ -0,0 +1,97 @@
+use std::path::{Path, PathBuf};
+use std::sync::Mutex;
+use std::time::SystemTime;
+
+/// A cached value derived from one or more files, which is automatically invalidated
+/// if the modified-time of any watched file changes.
+pub struct MtimeCache<T> {
+    inner: Mutex<Option<MtimeCacheInner<T>>>,
+}
+
+pub struct MtimeCacheBuilder {
+    file_mtimes: Vec<(PathBuf, Option<SystemTime>)>,
+}
+
+struct MtimeCacheInner<T> {
+    builder: MtimeCacheBuilder,
+    cached: T,
+}
+
+impl<T: Clone> MtimeCache<T> {
+    pub fn new() -> Self {
+        MtimeCache {
+            inner: Mutex::new(None),
+        }
+    }
+
+    /// Get the cached value, or (re)initialize it if it does not exist or a file's mtime has changed.
+    pub fn get_or_try_init<E>(
+        &self,
+        init: impl FnOnce(&mut MtimeCacheBuilder) -> Result<T, E>,
+    ) -> Result<T, E> {
+        let mut inner = self.inner.lock().unwrap_or_else(|e| {
+            // Reset the cache on panic.
+            let mut locked = e.into_inner();
+            *locked = None;
+            locked
+        });
+
+        if let Some(inner) = &*inner {
+            if !inner.builder.any_modified() {
+                return Ok(inner.cached.clone());
+            }
+        }
+
+        let mut builder = MtimeCacheBuilder::new();
+
+        let value = init(&mut builder)?;
+
+        *inner = Some(MtimeCacheInner {
+            builder,
+            cached: value.clone(),
+        });
+
+        Ok(value)
+    }
+}
+
+impl MtimeCacheBuilder {
+    fn new() -> Self {
+        MtimeCacheBuilder {
+            file_mtimes: Vec::new(),
+        }
+    }
+
+    /// Add a file path to watch.
+    ///
+    /// The cached value will be automatically invalidated if the modified-time of the file changes,
+    /// or if the file does not exist but is created sometime after this call.
+    pub fn add_path(&mut self, path: PathBuf) {
+        let mtime = get_mtime(&path);
+
+        #[cfg(any(sqlx_macros_unstable, procmacro2_semver_exempt))]
+        {
+            proc_macro::tracked_path::path(&path);
+        }
+
+        self.file_mtimes.push((path, mtime));
+    }
+
+    fn any_modified(&self) -> bool {
+        for (path, expected_mtime) in &self.file_mtimes {
+            let actual_mtime = get_mtime(path);
+
+            if expected_mtime != &actual_mtime {
+                return true;
+            }
+        }
+
+        false
+    }
+}
+
+fn get_mtime(path: &Path) -> Option<SystemTime> {
+    std::fs::metadata(path)
+        .and_then(|metadata| metadata.modified())
+        .ok()
+}
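How a caller is meant to use `MtimeCache`: register every file the value was derived from while computing it, and let `get_or_try_init` revalidate the recorded mtimes on each call. A hedged sketch; the `.env` path and error message are invented, and only `new`, `get_or_try_init`, and `add_path` come from the module above:

```rust
use std::path::PathBuf;

use crate::query::cache::MtimeCache;

// `T = String` here; `get_or_try_init` hands out clones, hence `T: Clone`.
fn read_dotenv_cached(cache: &MtimeCache<String>) -> crate::Result<String> {
    cache.get_or_try_init(|builder| {
        let path = PathBuf::from(".env");

        // Register the file *before* reading it: if its mtime changes later, or
        // the file appears after this lookup fails, the cached value is rebuilt.
        builder.add_path(path.clone());

        std::fs::read_to_string(&path)
            .map_err(|e| crate::Error::from(format!("failed to read {path:?}: {e}")))
    })
}
```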
diff --git a/sqlx-macros-core/src/query/data.rs b/sqlx-macros-core/src/query/data.rs
index 470f86f973..912236ae37 100644
--- a/sqlx-macros-core/src/query/data.rs
+++ b/sqlx-macros-core/src/query/data.rs
@@ -1,17 +1,18 @@
-use std::collections::HashMap;
 use std::fmt::{Debug, Display, Formatter};
 use std::fs;
 use std::io::Write as _;
 use std::marker::PhantomData;
 use std::path::{Path, PathBuf};
-use std::sync::{LazyLock, Mutex};
+use std::sync::{Arc, LazyLock, Mutex};
 
 use serde::{Serialize, Serializer};
 use sqlx_core::database::Database;
 use sqlx_core::describe::Describe;
+use sqlx_core::HashMap;
 
 use crate::database::DatabaseExt;
+use crate::query::cache::MtimeCache;
 
 #[derive(serde::Serialize)]
 #[serde(bound(serialize = "Describe<DB>: serde::Serialize"))]
@@ -64,7 +65,7 @@ impl<DB: Database> Serialize for SerializeDbName<DB> {
     }
 }
 
-static OFFLINE_DATA_CACHE: LazyLock<Mutex<HashMap<PathBuf, DynQueryData>>> =
+static OFFLINE_DATA_CACHE: LazyLock<Mutex<HashMap<PathBuf, Arc<MtimeCache<DynQueryData>>>>> =
     LazyLock::new(Default::default);
 
 /// Offline query data
@@ -79,47 +80,33 @@ pub struct DynQueryData {
 
 impl DynQueryData {
     /// Loads a query given the path to its "query-<hash>.json" file. Subsequent calls for the same
    /// path are retrieved from an in-memory cache.
-    pub fn from_data_file(path: impl AsRef<Path>, query: &str) -> crate::Result<Self> {
-        let path = path.as_ref();
-
-        let mut cache = OFFLINE_DATA_CACHE
+    pub fn from_data_file(path: &Path, query: &str) -> crate::Result<Self> {
+        let cache = OFFLINE_DATA_CACHE
             .lock()
             // Just reset the cache on error
             .unwrap_or_else(|poison_err| {
                 let mut guard = poison_err.into_inner();
                 *guard = Default::default();
                 guard
-            });
-        if let Some(cached) = cache.get(path).cloned() {
-            if query != cached.query {
-                return Err("hash collision for saved query data".into());
-            }
-            return Ok(cached);
-        }
-
-        #[cfg(procmacro2_semver_exempt)]
-        {
-            let path = path.as_ref().canonicalize()?;
-            let path = path.to_str().ok_or_else(|| {
-                format!(
-                    "query-<hash>.json path cannot be represented as a string: {:?}",
-                    path
-                )
-            })?;
+            })
+            .entry_ref(path)
+            .or_insert_with(|| Arc::new(MtimeCache::new()))
+            .clone();
 
-            proc_macro::tracked_path::path(path);
-        }
+        cache.get_or_try_init(|builder| {
+            builder.add_path(path.into());
 
-        let offline_data_contents = fs::read_to_string(path)
-            .map_err(|e| format!("failed to read saved query path {}: {}", path.display(), e))?;
-        let dyn_data: DynQueryData = serde_json::from_str(&offline_data_contents)?;
+            let offline_data_contents = fs::read_to_string(path).map_err(|e| {
+                format!("failed to read saved query path {}: {}", path.display(), e)
+            })?;
+            let dyn_data: DynQueryData = serde_json::from_str(&offline_data_contents)?;
 
-        if query != dyn_data.query {
-            return Err("hash collision for saved query data".into());
-        }
+            if query != dyn_data.query {
+                return Err("hash collision for saved query data".into());
+            }
 
-        let _ = cache.insert(path.to_owned(), dyn_data.clone());
-        Ok(dyn_data)
+            Ok(dyn_data)
+        })
     }
 }
@@ -149,41 +136,71 @@ where
     }
 
-    pub(super) fn save_in(&self, dir: impl AsRef<Path>) -> crate::Result<()> {
+    pub(super) fn save_in(&self, dir: &Path) -> crate::Result<()> {
         use std::io::ErrorKind;
 
-        let path = dir.as_ref().join(format!("query-{}.json", self.hash));
-        match std::fs::remove_file(&path) {
-            Ok(()) => {}
-            Err(err)
-                if matches!(
-                    err.kind(),
-                    ErrorKind::NotFound | ErrorKind::PermissionDenied,
-                ) => {}
-            Err(err) => return Err(format!("failed to delete {path:?}: {err:?}").into()),
+        let path = dir.join(format!("query-{}.json", self.hash));
+
+        if let Err(err) = fs::remove_file(&path) {
+            match err.kind() {
+                ErrorKind::NotFound | ErrorKind::PermissionDenied => (),
+                ErrorKind::NotADirectory => {
+                    return Err(format!(
+                        "sqlx offline path exists, but is not a directory: {dir:?}"
+                    )
+                    .into());
+                }
+                _ => return Err(format!("failed to delete {path:?}: {err:?}").into()),
+            }
         }
-
-        let mut file = match std::fs::OpenOptions::new()
+
+        // Prevent tearing from concurrent invocations possibly trying to write the same file
+        // by using the existence of the file itself as a mutex.
+        //
+        // By deleting the file first and then using `.create_new(true)`,
+        // we guarantee that this only succeeds if another invocation hasn't concurrently
+        // re-created the file.
+        let mut file = match fs::OpenOptions::new()
            .write(true)
            .create_new(true)
            .open(&path)
        {
            Ok(file) => file,
-            // We overlapped with a concurrent invocation and the other one succeeded.
-            Err(err) if matches!(err.kind(), ErrorKind::AlreadyExists) => return Ok(()),
            Err(err) => {
-                return Err(format!("failed to exclusively create {path:?}: {err:?}").into())
+                return match err.kind() {
+                    // We overlapped with a concurrent invocation and the other one succeeded.
+                    ErrorKind::AlreadyExists => Ok(()),
+                    ErrorKind::NotFound => {
+                        Err(format!("sqlx offline path does not exist: {dir:?}").into())
+                    }
+                    ErrorKind::NotADirectory => Err(format!(
+                        "sqlx offline path exists, but is not a directory: {dir:?}"
+                    )
+                    .into()),
+                    _ => Err(format!("failed to exclusively create {path:?}: {err:?}").into()),
+                };
            }
        };
 
-        let data = serde_json::to_string_pretty(self)
-            .map_err(|err| format!("failed to serialize query data: {err:?}"))?;
-        file.write_all(data.as_bytes())
-            .map_err(|err| format!("failed to write query data to file: {err:?}"))?;
+        // From a quick survey of the files generated by `examples/postgres/axum-social-with-tests`,
+        // which are generally in the 1-2 KiB range, this seems like a safe bet to avoid
+        // lots of reallocations without using too much memory.
+        //
+        // As of writing, `serde_json::to_vec_pretty()` only allocates 128 bytes up-front.
+        let mut data = Vec::with_capacity(4096);
+
+        serde_json::to_writer_pretty(&mut data, self).expect("BUG: failed to serialize query data");
 
        // Ensure there is a newline at the end of the JSON file to avoid
        // accidental modification by IDE and make github diff tool happier.
-        file.write_all(b"\n")
-            .map_err(|err| format!("failed to append a newline to file: {err:?}"))?;
+        data.push(b'\n');
+
+        // This ideally writes the data in as few syscalls as possible.
+        file.write_all(&data)
+            .map_err(|err| format!("failed to write query data to file {path:?}: {err:?}"))?;
+
+        // We don't really need to call `.sync_data()` since it's trivial to re-run the macro
+        // in the event that a power loss results in incomplete flushing of the data to disk.
 
        Ok(())
    }
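The locking trick in `save_in` generalizes: delete, then open with `create_new(true)` (which maps to `O_CREAT | O_EXCL`), so at most one concurrent invocation writes the file and the losers can simply defer to the winner. A reduced, self-contained sketch with invented names and an invented hash in the filename:

```rust
use std::fs::OpenOptions;
use std::io::{ErrorKind, Write};
use std::path::Path;

fn write_exclusive(path: &Path, data: &[u8]) -> std::io::Result<()> {
    // Drop any stale copy; a concurrent writer may race us past this point.
    let _ = std::fs::remove_file(path);

    // `create_new(true)` succeeds for exactly one of the racing invocations.
    let mut file = match OpenOptions::new().write(true).create_new(true).open(path) {
        Ok(file) => file,
        // Another invocation won the race; its data is equivalent to ours.
        Err(e) if e.kind() == ErrorKind::AlreadyExists => return Ok(()),
        Err(e) => return Err(e),
    };

    // A single write_all keeps the syscall count low, as in `save_in`.
    file.write_all(data)
}

fn main() -> std::io::Result<()> {
    write_exclusive(Path::new("query-deadbeef.json"), b"{}\n")
}
```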
diff --git a/sqlx-macros-core/src/query/metadata.rs b/sqlx-macros-core/src/query/metadata.rs
new file mode 100644
index 0000000000..1c85f5f394
--- /dev/null
+++ b/sqlx-macros-core/src/query/metadata.rs
@@ -0,0 +1,162 @@
+use sqlx_core::config::{Config, ConfigError};
+use std::hash::{BuildHasherDefault, DefaultHasher};
+use std::io;
+use std::path::{Path, PathBuf};
+use std::sync::{Arc, Mutex};
+
+use crate::query::cache::{MtimeCache, MtimeCacheBuilder};
+use sqlx_core::HashMap;
+
+pub struct Metadata {
+    pub manifest_dir: PathBuf,
+    pub config: Config,
+    env: MtimeCache<Arc<MacrosEnv>>,
+    workspace_root: Arc<Mutex<Option<PathBuf>>>,
+}
+
+pub struct MacrosEnv {
+    pub database_url: Option<String>,
+    pub offline_dir: Option<PathBuf>,
+    pub offline: Option<bool>,
+}
+
+impl Metadata {
+    pub fn env(&self) -> crate::Result<Arc<MacrosEnv>> {
+        self.env
+            .get_or_try_init(|builder| load_env(&self.manifest_dir, &self.config, builder))
+    }
+
+    pub fn workspace_root(&self) -> PathBuf {
+        let mut root = self.workspace_root.lock().unwrap();
+        if root.is_none() {
+            use serde::Deserialize;
+            use std::process::Command;
+
+            let cargo = crate::env("CARGO").unwrap();
+
+            let output = Command::new(cargo)
+                .args(["metadata", "--format-version=1", "--no-deps"])
+                .current_dir(&self.manifest_dir)
+                .env_remove("__CARGO_FIX_PLZ")
+                .output()
+                .expect("Could not fetch metadata");
+
+            #[derive(Deserialize)]
+            struct CargoMetadata {
+                workspace_root: PathBuf,
+            }
+
+            let metadata: CargoMetadata =
+                serde_json::from_slice(&output.stdout).expect("Invalid `cargo metadata` output");
+
+            *root = Some(metadata.workspace_root);
+        }
+        root.clone().unwrap()
+    }
+}
+
+pub fn try_for_crate() -> crate::Result<Arc<Metadata>> {
+    /// The `MtimeCache` in this type covers the config itself,
+    /// any changes to which will indirectly invalidate the loaded env vars as well.
+    #[expect(clippy::type_complexity)]
+    static METADATA: Mutex<
+        HashMap<String, Arc<MtimeCache<Arc<Metadata>>>, BuildHasherDefault<DefaultHasher>>,
+    > = Mutex::new(HashMap::with_hasher(BuildHasherDefault::new()));
+
+    let manifest_dir = crate::env("CARGO_MANIFEST_DIR")?;
+
+    let cache = METADATA
+        .lock()
+        .expect("BUG: we shouldn't panic while holding this lock")
+        .entry_ref(&manifest_dir)
+        .or_insert_with(|| Arc::new(MtimeCache::new()))
+        .clone();
+
+    cache.get_or_try_init(|builder| {
+        let manifest_dir = PathBuf::from(manifest_dir);
+        let config_path = manifest_dir.join("sqlx.toml");
+
+        builder.add_path(config_path.clone());
+
+        let config = Config::try_from_path_or_default(config_path)?;
+
+        Ok(Arc::new(Metadata {
+            manifest_dir,
+            config,
+            env: MtimeCache::new(),
+            workspace_root: Default::default(),
+        }))
+    })
+}
+
+fn load_env(
+    manifest_dir: &Path,
+    config: &Config,
+    builder: &mut MtimeCacheBuilder,
+) -> crate::Result<Arc<MacrosEnv>> {
+    #[derive(thiserror::Error, Debug)]
+    #[error("error reading dotenv file {path:?}")]
+    struct DotenvError {
+        path: PathBuf,
+        #[source]
+        error: dotenvy::Error,
+    }
+
+    let mut from_dotenv = MacrosEnv {
+        database_url: None,
+        offline_dir: None,
+        offline: None,
+    };
+
+    for dir in manifest_dir.ancestors() {
+        let path = dir.join(".env");
+
+        let dotenv = match dotenvy::from_path_iter(&path) {
+            Ok(iter) => {
+                builder.add_path(path.clone());
+                iter
+            }
+            Err(dotenvy::Error::Io(e)) if e.kind() == io::ErrorKind::NotFound => {
+                builder.add_path(dir.to_path_buf());
+                continue;
+            }
+            Err(e) => {
+                builder.add_path(path.clone());
+                return Err(DotenvError { path, error: e }.into());
+            }
+        };
+
+        for res in dotenv {
+            let (name, val) = res.map_err(|e| DotenvError {
+                path: path.clone(),
+                error: e,
+            })?;
+
+            match &*name {
+                "SQLX_OFFLINE_DIR" => from_dotenv.offline_dir = Some(val.into()),
+                "SQLX_OFFLINE" => from_dotenv.offline = Some(is_truthy_bool(&val)),
+                _ if name == config.common.database_url_var() => {
+                    from_dotenv.database_url = Some(val)
+                }
+                _ => continue,
+            }
+        }
+    }
+
+    Ok(Arc::new(MacrosEnv {
+        // Make variables set in the process environment take precedence.
+        database_url: crate::env_opt(config.common.database_url_var())?
+            .or(from_dotenv.database_url),
+        offline_dir: crate::env_opt("SQLX_OFFLINE_DIR")?
+            .map(PathBuf::from)
+            .or(from_dotenv.offline_dir),
+        offline: crate::env_opt("SQLX_OFFLINE")?
+            .map(|val| is_truthy_bool(&val))
+            .or(from_dotenv.offline),
+    }))
+}
+
+/// Returns `true` if `val` equals `"true"` (ignoring case) or `"1"`.
+fn is_truthy_bool(val: &str) -> bool {
+    val.eq_ignore_ascii_case("true") || val == "1"
+}
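The net effect of `load_env`: `.env` files are discovered by walking up from `CARGO_MANIFEST_DIR`, but a variable set in the real process environment always beats one harvested from any `.env` file. A minimal sketch of just that precedence rule (the real code also errors on non-UTF-8 values, which this simplification ignores); `from_dotenv` stands in for the value collected from dotenv files, and the variable names in `main` are arbitrary:

```rust
fn resolve(var: &str, from_dotenv: Option<String>) -> Option<String> {
    // Process environment first, dotenv value as the fallback;
    // the same `.or()` shape as the `MacrosEnv` construction above.
    std::env::var(var).ok().or(from_dotenv)
}

fn main() {
    // `PATH` is set in the process environment, so the dotenv value loses.
    let effective = resolve("PATH", Some("/from/dotenv".to_owned()));
    assert_ne!(effective.as_deref(), Some("/from/dotenv"));

    // An unset variable falls back to whatever the `.env` file said.
    let fallback = resolve("SQLX_SURELY_UNSET_VAR", Some("from dotenv".to_owned()));
    assert_eq!(fallback.as_deref(), Some("from dotenv"));
}
```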
diff --git a/sqlx-macros-core/src/query/mod.rs b/sqlx-macros-core/src/query/mod.rs
index 060a24b847..66620f9cea 100644
--- a/sqlx-macros-core/src/query/mod.rs
+++ b/sqlx-macros-core/src/query/mod.rs
@@ -1,7 +1,4 @@
-use std::collections::{hash_map, HashMap};
 use std::path::{Path, PathBuf};
-use std::sync::{Arc, LazyLock, Mutex};
-use std::{fs, io};
 
 use proc_macro2::TokenStream;
 use syn::Type;
@@ -14,20 +11,25 @@ use sqlx_core::{column::Column, describe::Describe, type_info::TypeInfo};
 use crate::database::DatabaseExt;
 use crate::query::data::{hash_string, DynQueryData, QueryData};
 use crate::query::input::RecordType;
+use crate::query::metadata::MacrosEnv;
 use either::Either;
+use metadata::Metadata;
 use sqlx_core::config::Config;
 use url::Url;
 
 mod args;
+mod cache;
 mod data;
 mod input;
+mod metadata;
 mod output;
 
 #[derive(Copy, Clone)]
 pub struct QueryDriver {
     db_name: &'static str,
     url_schemes: &'static [&'static str],
-    expand: fn(&Config, QueryMacroInput, QueryDataSource) -> crate::Result<TokenStream>,
+    expand:
+        fn(&Config, QueryMacroInput, QueryDataSource, Option<&Path>) -> crate::Result<TokenStream>,
 }
 
 impl QueryDriver {
@@ -68,138 +70,62 @@ impl<'a> QueryDataSource<'a> {
         }
     }
 }
-
-struct Metadata {
-    #[allow(unused)]
-    manifest_dir: PathBuf,
-    offline: bool,
-    database_url: Option<String>,
-    offline_dir: Option<String>,
-    config: Config,
-    workspace_root: Arc<Mutex<Option<PathBuf>>>,
-}
-
-impl Metadata {
-    pub fn workspace_root(&self) -> PathBuf {
-        let mut root = self.workspace_root.lock().unwrap();
-        if root.is_none() {
-            use serde::Deserialize;
-            use std::process::Command;
-
-            let cargo = env("CARGO").expect("`CARGO` must be set");
-
-            let output = Command::new(cargo)
-                .args(["metadata", "--format-version=1", "--no-deps"])
-                .current_dir(&self.manifest_dir)
-                .env_remove("__CARGO_FIX_PLZ")
-                .output()
-                .expect("Could not fetch metadata");
-
-            #[derive(Deserialize)]
-            struct CargoMetadata {
-                workspace_root: PathBuf,
-            }
-
-            let metadata: CargoMetadata =
-                serde_json::from_slice(&output.stdout).expect("Invalid `cargo metadata` output");
-
-            *root = Some(metadata.workspace_root);
-        }
-        root.clone().unwrap()
-    }
-}
-
-static METADATA: LazyLock<Mutex<HashMap<String, Metadata>>> = LazyLock::new(Default::default);
-
-// If we are in a workspace, lookup `workspace_root` since `CARGO_MANIFEST_DIR` won't
-// reflect the workspace dir: https://github.com/rust-lang/cargo/issues/3946
-fn init_metadata(manifest_dir: &String) -> crate::Result<Metadata> {
-    let manifest_dir: PathBuf = manifest_dir.into();
-
-    let (database_url, offline, offline_dir) = load_dot_env(&manifest_dir);
-
-    let offline = env("SQLX_OFFLINE")
-        .ok()
-        .or(offline)
-        .map(|s| s.eq_ignore_ascii_case("true") || s == "1")
-        .unwrap_or(false);
-
-    let offline_dir = env("SQLX_OFFLINE_DIR").ok().or(offline_dir);
-
-    let config = Config::try_from_crate_or_default()?;
-
-    let database_url = env(config.common.database_url_var()).ok().or(database_url);
-
-    Ok(Metadata {
-        manifest_dir,
-        offline,
-        database_url,
-        offline_dir,
-        config,
-        workspace_root: Arc::new(Mutex::new(None)),
-    })
-}
-
 pub fn expand_input<'a>(
     input: QueryMacroInput,
     drivers: impl IntoIterator<Item = &'a QueryDriver>,
 ) -> crate::Result<TokenStream> {
-    let manifest_dir = env("CARGO_MANIFEST_DIR").expect("`CARGO_MANIFEST_DIR` must be set");
-
-    let mut metadata_lock = METADATA
-        .lock()
-        // Just reset the metadata on error
-        .unwrap_or_else(|poison_err| {
-            let mut guard = poison_err.into_inner();
-            *guard = Default::default();
-            guard
-        });
+    let metadata = metadata::try_for_crate()?;
 
-    let metadata = match metadata_lock.entry(manifest_dir) {
-        hash_map::Entry::Occupied(occupied) => occupied.into_mut(),
-        hash_map::Entry::Vacant(vacant) => {
-            let metadata = init_metadata(vacant.key())?;
-            vacant.insert(metadata)
-        }
-    };
+    let metadata_env = metadata.env()?;
 
-    let data_source = match &metadata {
-        Metadata {
-            offline: false,
+    let data_source = match &*metadata_env {
+        MacrosEnv {
+            offline: None | Some(false),
             database_url: Some(db_url),
             ..
         } => QueryDataSource::live(db_url)?,
-        Metadata { offline, .. } => {
+        MacrosEnv {
+            offline,
+            offline_dir,
+            ..
+        } => {
             // Try load the cached query metadata file.
             let filename = format!("query-{}.json", hash_string(&input.sql));
 
             // Check SQLX_OFFLINE_DIR, then local .sqlx, then workspace .sqlx.
             let dirs = [
-                |meta: &Metadata| meta.offline_dir.as_deref().map(PathBuf::from),
-                |meta: &Metadata| Some(meta.manifest_dir.join(".sqlx")),
-                |meta: &Metadata| Some(meta.workspace_root().join(".sqlx")),
+                |_: &Metadata, offline_dir: Option<&Path>| offline_dir.map(PathBuf::from),
+                |meta: &Metadata, _: Option<&Path>| Some(meta.manifest_dir.join(".sqlx")),
+                |meta: &Metadata, _: Option<&Path>| Some(meta.workspace_root().join(".sqlx")),
             ];
+
             let Some(data_file_path) = dirs
                 .iter()
-                .filter_map(|path| path(metadata))
+                .filter_map(|path| path(&metadata, offline_dir.as_deref()))
                 .map(|path| path.join(&filename))
                 .find(|path| path.exists())
             else {
                 return Err(
-                    if *offline {
+                    if offline.unwrap_or(false) {
                         "`SQLX_OFFLINE=true` but there is no cached data for this query, run `cargo sqlx prepare` to update the query cache or unset `SQLX_OFFLINE`"
                     } else {
                         "set `DATABASE_URL` to use query macros online, or run `cargo sqlx prepare` to update the query cache"
                    }.into()
                );
            };
+
             QueryDataSource::Cached(DynQueryData::from_data_file(&data_file_path, &input.sql)?)
         }
     };
 
     for driver in drivers {
         if data_source.matches_driver(driver) {
-            return (driver.expand)(&metadata.config, input, data_source);
+            return (driver.expand)(
+                &metadata.config,
+                input,
+                data_source,
+                metadata_env.offline_dir.as_deref(),
+            );
         }
     }
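For reference, the lookup order `expand_input` now encodes (explicit `SQLX_OFFLINE_DIR`, then the crate-local `.sqlx`, then the workspace `.sqlx`) extracted into a standalone sketch; the directory arguments in `main` and the hash in the filename are invented:

```rust
use std::path::{Path, PathBuf};

fn find_query_data(
    offline_dir: Option<&Path>,
    manifest_dir: &Path,
    workspace_root: &Path,
    filename: &str,
) -> Option<PathBuf> {
    let candidates = [
        offline_dir.map(Path::to_path_buf),
        Some(manifest_dir.join(".sqlx")),
        Some(workspace_root.join(".sqlx")),
    ];

    // The first directory that actually contains the file wins.
    candidates
        .into_iter()
        .flatten()
        .map(|dir| dir.join(filename))
        .find(|path| path.exists())
}

fn main() {
    let found = find_query_data(None, Path::new("."), Path::new(".."), "query-deadbeef.json");
    println!("{found:?}");
}
```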
@@ -224,19 +150,21 @@ fn expand_with<DB: DatabaseExt>(
     config: &Config,
     input: QueryMacroInput,
     data_source: QueryDataSource,
+    offline_dir: Option<&Path>,
 ) -> crate::Result<TokenStream>
 where
     Describe<DB>: DescribeExt,
 {
-    let (query_data, offline): (QueryData<DB>, bool) = match data_source {
-        QueryDataSource::Cached(dyn_data) => (QueryData::from_dyn_data(dyn_data)?, true),
+    let (query_data, save_dir): (QueryData<DB>, Option<&Path>) = match data_source {
+        // If the build is offline, the cache is our input so it's pointless to also write data for it.
+        QueryDataSource::Cached(dyn_data) => (QueryData::from_dyn_data(dyn_data)?, None),
         QueryDataSource::Live { database_url, .. } => {
             let describe = DB::describe_blocking(&input.sql, database_url, &config.drivers)?;
-            (QueryData::from_describe(&input.sql, describe), false)
+            (QueryData::from_describe(&input.sql, describe), offline_dir)
         }
     };
 
-    expand_with_data(config, input, query_data, offline)
+    expand_with_data(config, input, query_data, save_dir)
 }
 
 // marker trait for `Describe` that lets us conditionally require it to be `Serialize + Deserialize`
@@ -257,7 +185,7 @@ fn expand_with_data<DB: DatabaseExt>(
     config: &Config,
     input: QueryMacroInput,
     data: QueryData<DB>,
-    offline: bool,
+    save_dir: Option<&Path>,
 ) -> crate::Result<TokenStream>
 where
     Describe<DB>: DescribeExt,
@@ -380,99 +308,9 @@ where
         }
     };
 
-    // Store query metadata only if offline support is enabled but the current build is online.
-    // If the build is offline, the cache is our input so it's pointless to also write data for it.
-    if !offline {
-        // Only save query metadata if SQLX_OFFLINE_DIR is set manually or by `cargo sqlx prepare`.
-        // Note: in a cargo workspace this path is relative to the root.
-        if let Ok(dir) = env("SQLX_OFFLINE_DIR") {
-            let path = PathBuf::from(&dir);
-
-            match fs::metadata(&path) {
-                Err(e) => {
-                    if e.kind() != io::ErrorKind::NotFound {
-                        // Can't obtain information about .sqlx
-                        return Err(format!("{e}: {dir}").into());
-                    }
-                    // .sqlx doesn't exist.
-                    return Err(format!("sqlx offline path does not exist: {dir}").into());
-                }
-                Ok(meta) => {
-                    if !meta.is_dir() {
-                        return Err(format!(
-                            "sqlx offline path exists, but is not a directory: {dir}"
-                        )
-                        .into());
-                    }
-
-                    // .sqlx exists and is a directory, store data.
-                    data.save_in(path)?;
-                }
-            }
-        }
+    if let Some(save_dir) = save_dir {
+        data.save_in(save_dir)?;
     }
 
     Ok(ret_tokens)
 }
-
-/// Get the value of an environment variable, telling the compiler about it if applicable.
-fn env(name: &str) -> Result<String, std::env::VarError> {
-    #[cfg(procmacro2_semver_exempt)]
-    {
-        proc_macro::tracked_env::var(name)
-    }
-
-    #[cfg(not(procmacro2_semver_exempt))]
-    {
-        std::env::var(name)
-    }
-}
-
-/// Get `DATABASE_URL`, `SQLX_OFFLINE` and `SQLX_OFFLINE_DIR` from the `.env`.
-fn load_dot_env(manifest_dir: &Path) -> (Option<String>, Option<String>, Option<String>) {
-    let mut env_path = manifest_dir.join(".env");
-
-    // If a .env file exists at CARGO_MANIFEST_DIR, load environment variables from this,
-    // otherwise fallback to default dotenv file.
-    #[cfg_attr(not(procmacro2_semver_exempt), allow(unused_variables))]
-    let env_file = if env_path.exists() {
-        let res = dotenvy::from_path_iter(&env_path);
-        match res {
-            Ok(iter) => Some(iter),
-            Err(e) => panic!("failed to load environment from {env_path:?}, {e}"),
-        }
-    } else {
-        #[allow(unused_assignments)]
-        {
-            env_path = PathBuf::from(".env");
-        }
-        dotenvy::dotenv_iter().ok()
-    };
-
-    let mut offline = None;
-    let mut database_url = None;
-    let mut offline_dir = None;
-
-    if let Some(env_file) = env_file {
-        // tell the compiler to watch the `.env` for changes.
-        #[cfg(procmacro2_semver_exempt)]
-        if let Some(env_path) = env_path.to_str() {
-            proc_macro::tracked_path::path(env_path);
-        }
-
-        for item in env_file {
-            let Ok((key, value)) = item else {
-                continue;
-            };
-
-            match key.as_str() {
-                "DATABASE_URL" => database_url = Some(value),
-                "SQLX_OFFLINE" => offline = Some(value),
-                "SQLX_OFFLINE_DIR" => offline_dir = Some(value),
-                _ => {}
-            };
-        }
-    }
-
-    (database_url, offline, offline_dir)
-}
diff --git a/sqlx-mysql/Cargo.toml b/sqlx-mysql/Cargo.toml
index 52717c4207..2ed0738a24 100644
--- a/sqlx-mysql/Cargo.toml
+++ b/sqlx-mysql/Cargo.toml
@@ -55,7 +55,6 @@ base64 = { version = "0.22.0", default-features = false, features = ["std"] }
 bitflags = { version = "2", default-features = false, features = ["serde"] }
 byteorder = { version = "1.4.3", default-features = false, features = ["std"] }
 bytes = "1.1.0"
-dotenvy = "0.15.5"
 either = "1.6.1"
 generic-array = { version = "0.14.4", default-features = false }
 hex = "0.4.3"
@@ -65,10 +64,12 @@ memchr = { version = "2.4.1", default-features = false }
 percent-encoding = "2.1.0"
 smallvec = "1.7.0"
 stringprep = "0.1.2"
-thiserror = "2.0.0"
 tracing = { version = "0.1.37", features = ["log"] }
 whoami = { version = "1.2.1", default-features = false }
 
+dotenvy.workspace = true
+thiserror.workspace = true
+
 serde = { version = "1.0.144", optional = true }
 
 [dev-dependencies]
diff --git a/sqlx-postgres/Cargo.toml b/sqlx-postgres/Cargo.toml
index a70fb37d72..4abd252357 100644
--- a/sqlx-postgres/Cargo.toml
+++ b/sqlx-postgres/Cargo.toml
@@ -56,7 +56,6 @@ atoi = "2.0"
 base64 = { version = "0.22.0", default-features = false, features = ["std"] }
 bitflags = { version = "2", default-features = false }
 byteorder = { version = "1.4.3", default-features = false, features = ["std"] }
-dotenvy = { workspace = true }
 hex = "0.4.3"
 home = "0.5.5"
 itoa = "1.0.1"
@@ -65,10 +64,12 @@ memchr = { version = "2.4.1", default-features = false }
 num-bigint = { version = "0.4.3", optional = true }
 smallvec = { version = "1.7.0", features = ["serde"] }
 stringprep = "0.1.2"
-thiserror = "2.0.0"
 tracing = { version = "0.1.37", features = ["log"] }
 whoami = { version = "1.2.1", default-features = false }
 
+dotenvy.workspace = true
+thiserror.workspace = true
+
 serde = { version = "1.0.144", features = ["derive"] }
 serde_json = { version = "1.0.85", features = ["raw_value"] }
 
diff --git a/sqlx-sqlite/Cargo.toml b/sqlx-sqlite/Cargo.toml
index 4508e19ff6..6e04f1c2a6 100644
--- a/sqlx-sqlite/Cargo.toml
+++ b/sqlx-sqlite/Cargo.toml
@@ -85,7 +85,8 @@ atoi = "2.0"
 
 log = "0.4.18"
 tracing = { version = "0.1.37", features = ["log"] }
-thiserror = "2.0.0"
+
+thiserror.workspace = true
 
 serde = { version = "1.0.145", features = ["derive"], optional = true }
 regex = { version = "1.5.5", optional = true }