Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
87 changes: 87 additions & 0 deletions crates/pyrefly_python/src/sys_info.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,11 @@ use ruff_python_ast::ExprAttribute;
use ruff_python_ast::ExprBooleanLiteral;
use ruff_python_ast::ExprCall;
use ruff_python_ast::ExprList;
use ruff_python_ast::ExprName;
use ruff_python_ast::ExprNumberLiteral;
use ruff_python_ast::ExprSet;
use ruff_python_ast::ExprSlice;
use ruff_python_ast::ExprStringLiteral;
use ruff_python_ast::ExprSubscript;
use ruff_python_ast::Stmt;
use ruff_python_ast::StmtIf;
Expand Down Expand Up @@ -267,6 +269,14 @@ impl SysInfo {
pub fn type_checking(&self) -> bool {
self.0.key().type_checking
}

pub fn with_platform(&self, platform: PythonPlatform) -> Self {
if self.type_checking() {
SysInfo::new(self.version(), platform)
} else {
SysInfo::new_without_type_checking(self.version(), platform)
}
}
}

impl<'de> Deserialize<'de> for SysInfo {
Expand Down Expand Up @@ -697,6 +707,83 @@ impl SysInfo {
}
}

pub fn module_platform_guard(body: &[Stmt]) -> Option<PythonPlatform> {
body.iter().find_map(platform_guard_from_stmt)
Copy link

Copilot AI Feb 27, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

module_platform_guard scans the entire module body (iter().find_map(...)) and will treat any matching assert sys.platform == ... / if sys.platform != ...: raise anywhere in the file as a module-level guard. That can incorrectly override SysInfo for modules that merely perform a later runtime check. Consider restricting the scan to the leading top-level statements (e.g., after an optional module docstring and possibly import sys) so only true module-guard patterns trigger an override.

Suggested change
body.iter().find_map(platform_guard_from_stmt)
for stmt in body {
if let Some(platform) = platform_guard_from_stmt(stmt) {
return Some(platform);
}
match stmt {
// Allow leading trivial statements (e.g. module docstring and imports)
Stmt::Expr(_) | Stmt::Import(_) | Stmt::ImportFrom(_) => continue,
// Any other non-guard statement means there is no module-level platform guard.
_ => break,
}
}
None

Copilot uses AI. Check for mistakes.
}

fn platform_guard_from_stmt(stmt: &Stmt) -> Option<PythonPlatform> {
match stmt {
Stmt::Assert(x) => platform_guard_from_assert(&x.test),
Stmt::If(x) if x.elif_else_clauses.is_empty() && body_is_unconditional_raise(&x.body) => {
platform_guard_from_not_eq(&x.test)
}
_ => None,
}
}

fn body_is_unconditional_raise(body: &[Stmt]) -> bool {
matches!(body, [Stmt::Raise(_)])
}

fn platform_guard_from_assert(test: &Expr) -> Option<PythonPlatform> {
let (op, platform) = platform_compare(test)?;
match op {
CmpOp::Eq => Some(PythonPlatform::new(&platform)),
_ => None,
}
}

fn platform_guard_from_not_eq(test: &Expr) -> Option<PythonPlatform> {
let (op, platform) = platform_compare(test)?;
match op {
CmpOp::NotEq => Some(PythonPlatform::new(&platform)),
_ => None,
}
}

fn platform_compare(test: &Expr) -> Option<(CmpOp, String)> {
let Expr::Compare(compare) = test else {
return None;
};
if compare.ops.len() != 1 || compare.comparators.len() != 1 {
return None;
}
let op = compare.ops[0];
if !matches!(op, CmpOp::Eq | CmpOp::NotEq) {
return None;
}
let left = &compare.left;
let right = &compare.comparators[0];
let platform = extract_platform_literal(left, right)?;
Some((op, platform))
}

fn extract_platform_literal(left: &Expr, right: &Expr) -> Option<String> {
if is_sys_platform(left) {
string_literal(right)
} else if is_sys_platform(right) {
string_literal(left)
} else {
None
}
}

fn is_sys_platform(expr: &Expr) -> bool {
matches!(
expr,
Expr::Attribute(ExprAttribute { value, attr, .. })
if matches!(&**value, Expr::Name(ExprName { id, .. }) if id == "sys")
&& attr.as_str() == "platform"
)
}

fn string_literal(expr: &Expr) -> Option<String> {
match expr {
Expr::StringLiteral(ExprStringLiteral { value, .. }) => Some(value.to_str().to_owned()),
_ => None,
}
}

#[cfg(test)]
mod tests {
use super::*;
Expand Down
4 changes: 4 additions & 0 deletions pyrefly/lib/binding/bindings.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ use pyrefly_python::module_name::ModuleName;
use pyrefly_python::nesting_context::NestingContext;
use pyrefly_python::short_identifier::ShortIdentifier;
use pyrefly_python::sys_info::SysInfo;
use pyrefly_python::sys_info::module_platform_guard;
use pyrefly_types::callable::FuncDefIndex;
use pyrefly_types::type_alias::TypeAliasIndex;
use pyrefly_types::type_info::JoinStyle;
Expand Down Expand Up @@ -473,6 +474,9 @@ impl Bindings {
enable_trace: bool,
untyped_def_behavior: UntypedDefBehavior,
) -> Self {
let override_sys_info =
module_platform_guard(&x.body).map(|platform| sys_info.with_platform(platform));
let sys_info = override_sys_info.as_ref().unwrap_or(sys_info);
let mut builder = BindingsBuilder {
module_info: module_info.dupe(),
lookup,
Expand Down
54 changes: 39 additions & 15 deletions pyrefly/lib/state/state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ use pyrefly_python::module_name::ModuleName;
use pyrefly_python::module_path::ModulePath;
use pyrefly_python::module_path::ModulePathDetails;
use pyrefly_python::sys_info::SysInfo;
use pyrefly_python::sys_info::module_platform_guard;
use pyrefly_types::type_alias::TypeAliasIndex;
use pyrefly_util::arc_id::ArcId;
use pyrefly_util::events::CategorizedEvents;
Expand Down Expand Up @@ -365,6 +366,7 @@ impl ModuleDeps {
struct ModuleData {
handle: Handle,
config: ArcId<ConfigFile>,
effective_sys_info: SysInfo,
state: ModuleState,
imports: HashMap<ModuleName, FindingOrError<ModulePath>, BuildNoHash>,
deps: HashMap<Handle, ModuleDeps>,
Expand All @@ -375,6 +377,7 @@ struct ModuleData {
struct ModuleDataMut {
handle: Handle,
config: RwLock<ArcId<ConfigFile>>,
effective_sys_info: RwLock<SysInfo>,
state: ModuleStateMut,
/// Import resolution cache: module names from import statements → resolved paths.
/// Only contains deps that were resolved via `find_import`.
Expand All @@ -389,12 +392,21 @@ struct ModuleDataMut {
rdeps: Mutex<HashSet<Handle>>,
}

fn module_sys_info_override(
sys_info: &SysInfo,
ast: Option<&ruff_python_ast::ModModule>,
) -> Option<SysInfo> {
let ast = ast?;
module_platform_guard(&ast.body).map(|platform| sys_info.with_platform(platform))
}

impl ModuleData {
/// Make a copy of the data that can be mutated.
fn clone_for_mutation(&self) -> ModuleDataMut {
ModuleDataMut {
handle: self.handle.dupe(),
config: RwLock::new(self.config.dupe()),
effective_sys_info: RwLock::new(self.effective_sys_info.dupe()),
state: self.state.clone_for_mutation(),
imports: RwLock::new(self.imports.clone()),
deps: RwLock::new(self.deps.clone()),
Expand All @@ -405,9 +417,11 @@ impl ModuleData {

impl ModuleDataMut {
fn new(handle: Handle, require: Require, config: ArcId<ConfigFile>, now: Epoch) -> Self {
let effective_sys_info = handle.sys_info().dupe();
Self {
handle,
config: RwLock::new(config),
effective_sys_info: RwLock::new(effective_sys_info),
state: ModuleStateMut::new(require, now),
imports: Default::default(),
deps: Default::default(),
Expand All @@ -421,6 +435,7 @@ impl ModuleDataMut {
let ModuleDataMut {
handle,
config,
effective_sys_info,
state,
imports,
deps,
Expand All @@ -433,13 +448,26 @@ impl ModuleDataMut {
ModuleData {
handle: handle.dupe(),
config: config.read().dupe(),
effective_sys_info: effective_sys_info.read().dupe(),
state,
imports,
deps,
rdeps,
}
}

fn effective_sys_info(&self) -> SysInfo {
if let Some(ast) = self.state.get_ast().as_deref() {
let base = self.handle.sys_info();
let effective =
module_sys_info_override(base, Some(ast)).unwrap_or_else(|| base.dupe());
*self.effective_sys_info.write() = effective.dupe();
effective
} else {
self.effective_sys_info.read().dupe()
}
}

/// Look up how this module depends on a specific source handle.
/// Returns the `ModuleDep` if this module depends on `source_handle`, or `None` if not found.
fn get_depends_on(&self, source_handle: &Handle) -> Option<ModuleDeps> {
Expand Down Expand Up @@ -1014,11 +1042,12 @@ impl<'a> Transaction<'a> {
let require = guard.require();
let stdlib = self.get_stdlib(&module_data.handle);
let config = module_data.config.read();
let sys_info = module_data.effective_sys_info();
let ctx = Context {
require,
module: module_data.handle.module(),
path: module_data.handle.path(),
sys_info: module_data.handle.sys_info(),
sys_info: &sys_info,
memory: &self.memory_lookup(),
uniques: &self.data.state.uniques,
stdlib: &stdlib,
Expand Down Expand Up @@ -1402,13 +1431,17 @@ impl<'a> Transaction<'a> {
.dupe()
}

pub fn get_stdlib(&self, handle: &Handle) -> Arc<Stdlib> {
pub fn get_stdlib_for_sys_info(&self, sys_info: &SysInfo) -> Arc<Stdlib> {
if self.data.stdlib.len() == 1 {
// Since we know our one must exist, we can shortcut
return self.data.stdlib.first().unwrap().1.dupe();
}

self.data.stdlib.get(handle.sys_info()).unwrap().dupe()
self.data.stdlib.get(sys_info).unwrap().dupe()
}
Comment on lines +1434 to +1441
Copy link

Copilot AI Feb 27, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

get_stdlib_for_sys_info returns the only cached stdlib whenever self.data.stdlib.len() == 1, even if the requested sys_info is different. With platform-guarded modules you can now create Handles with a different platform than the initial run, so this shortcut can silently return a stdlib for the wrong platform (or mask missing entries). Consider removing the shortcut or only using it when the sole key matches sys_info (and otherwise ensure the stdlib for sys_info is computed/available).

Copilot uses AI. Check for mistakes.

pub fn get_stdlib(&self, handle: &Handle) -> Arc<Stdlib> {
self.get_stdlib_for_sys_info(handle.sys_info())
}

fn compute_stdlib(&mut self, sys_infos: SmallSet<SysInfo>) {
Expand Down Expand Up @@ -1983,14 +2016,11 @@ impl<'a> TransactionHandle<'a> {
path: Option<&ModulePath>,
dep: ModuleDep,
) -> FindingOrError<ArcId<ModuleDataMut>> {
let sys_info = self.module_data.effective_sys_info();
let handle = match path {
Some(path) => {
// Explicit path — already resolved. Bypass imports entirely.
FindingOrError::new_finding(Handle::new(
module,
path.dupe(),
self.module_data.handle.sys_info().dupe(),
))
FindingOrError::new_finding(Handle::new(module, path.dupe(), sys_info.dupe()))
}
None => {
// No path — needs find_import. Check imports cache first.
Expand All @@ -2010,13 +2040,7 @@ impl<'a> TransactionHandle<'a> {
finding
}
};
path.map(|path| {
Handle::new(
module,
path.dupe(),
self.module_data.handle.sys_info().dupe(),
)
})
path.map(|path| Handle::new(module, path.dupe(), sys_info.dupe()))
}
};

Expand Down
16 changes: 16 additions & 0 deletions pyrefly/lib/test/sys_info.rs
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,22 @@ assert_type(Z(), str)
"#,
);

testcase!(
test_module_platform_guard_imports,
TestEnv::one(
"winonly",
r#"
import sys
assert sys.platform == "win32"
import winreg
x = winreg.OpenKey
"#,
),
r#"
import winonly
"#,
);

testcase!(
test_os_name,
r#"
Expand Down
Loading