Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
128 changes: 74 additions & 54 deletions sqlx-macros-core/src/query/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -116,19 +116,18 @@ static METADATA: LazyLock<Mutex<HashMap<String, Metadata>>> = LazyLock::new(Defa
fn init_metadata(manifest_dir: &String) -> crate::Result<Metadata> {
let manifest_dir: PathBuf = manifest_dir.into();

let (database_url, offline, offline_dir) = load_dot_env(&manifest_dir);
let env_file = DotEnvFile::load(&manifest_dir);

let offline = env("SQLX_OFFLINE")
.ok()
.or(offline)
let offline = env_file
.env("SQLX_OFFLINE")
.map(|s| s.eq_ignore_ascii_case("true") || s == "1")
.unwrap_or(false);

let offline_dir = env("SQLX_OFFLINE_DIR").ok().or(offline_dir);
let offline_dir = env_file.env("SQLX_OFFLINE_DIR").ok();

let config = Config::try_from_crate_or_default()?;

let database_url = env(config.common.database_url_var()).ok().or(database_url);
let database_url = env_file.env(config.common.database_url_var()).ok();

Ok(Metadata {
manifest_dir,
Expand Down Expand Up @@ -415,64 +414,85 @@ where
Ok(ret_tokens)
}

/// Get the value of an environment variable, telling the compiler about it if applicable.
fn env(name: &str) -> Result<String, std::env::VarError> {
#[cfg(procmacro2_semver_exempt)]
{
proc_macro::tracked_env::var(name)
}

#[cfg(not(procmacro2_semver_exempt))]
{
std::env::var(name)
}
#[derive(Default)]
struct DotEnvFile {
vars: HashMap<String, String>,
}

/// Get `DATABASE_URL`, `SQLX_OFFLINE` and `SQLX_OFFLINE_DIR` from the `.env`.
fn load_dot_env(manifest_dir: &Path) -> (Option<String>, Option<String>, Option<String>) {
let mut env_path = manifest_dir.join(".env");

// If a .env file exists at CARGO_MANIFEST_DIR, load environment variables from this,
// otherwise fallback to default dotenv file.
#[cfg_attr(not(procmacro2_semver_exempt), allow(unused_variables))]
let env_file = if env_path.exists() {
let res = dotenvy::from_path_iter(&env_path);
match res {
Ok(iter) => Some(iter),
Err(e) => panic!("failed to load environment from {env_path:?}, {e}"),
impl DotEnvFile {
fn load(manifest_dir: &Path) -> Self {
let mut result = Self::default();

let (found_dotenv, candidate_dotenv_paths) = Self::find(manifest_dir);

// Tell the compiler to watch the candidate `.env` paths for changes. It's important to
// watch them all, because there are several possible locations where a `.env` file
// might be read, and we want to react to changes in any of them.
#[cfg(procmacro2_semver_exempt)]
for path in &candidate_dotenv_paths {
if let Some(path) = path.to_str() {
proc_macro::tracked_path::path(path);
}
}
} else {
#[allow(unused_assignments)]

if let Some(dotenv_path) = found_dotenv
.then_some(candidate_dotenv_paths)
.iter()
.flatten()
.last()
{
env_path = PathBuf::from(".env");
for dotenv_var_result in dotenvy::from_path_iter(dotenv_path)
.ok()
.into_iter()
.flatten()
{
let Ok((key, value)) = dotenv_var_result else {
continue;
};

result.vars.insert(key, value);
}
}
dotenvy::dotenv_iter().ok()
};
result
}

let mut offline = None;
let mut database_url = None;
let mut offline_dir = None;
fn find(mut dir: &Path) -> (bool, Vec<PathBuf>) {
let mut candidate_files = vec![];

if let Some(env_file) = env_file {
// tell the compiler to watch the `.env` for changes.
#[cfg(procmacro2_semver_exempt)]
if let Some(env_path) = env_path.to_str() {
proc_macro::tracked_path::path(env_path);
}
loop {
candidate_files.push(dir.join(".env"));
let candidate_file = candidate_files.last().unwrap();

for item in env_file {
let Ok((key, value)) = item else {
continue;
};
if candidate_file.is_file() {
return (true, candidate_files);
}

match key.as_str() {
"DATABASE_URL" => database_url = Some(value),
"SQLX_OFFLINE" => offline = Some(value),
"SQLX_OFFLINE_DIR" => offline_dir = Some(value),
_ => {}
};
if let Some(parent) = dir.parent() {
dir = parent;
} else {
return (false, candidate_files);
}
}
}

(database_url, offline, offline_dir)
/// Get the value of an environment variable, telling the compiler about it if applicable.
fn env(&self, name: &str) -> Result<String, std::env::VarError> {
if let Some(val) = self.vars.get(name) {
Ok(val.clone())
} else {
env(name)
}
}
}

fn env(name: &str) -> Result<String, std::env::VarError> {
#[cfg(procmacro2_semver_exempt)]
{
proc_macro::tracked_env::var(name)
}

#[cfg(not(procmacro2_semver_exempt))]
{
std::env::var(name)
}
}
Loading