Skip to content

Commit 8847dcd

Browse files
fix
1 parent 4ff2b9d commit 8847dcd

File tree

4 files changed

+139
-15
lines changed

4 files changed

+139
-15
lines changed

crates/pyrefly_python/src/sys_info.rs

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,11 @@ use ruff_python_ast::ExprAttribute;
2626
use ruff_python_ast::ExprBooleanLiteral;
2727
use ruff_python_ast::ExprCall;
2828
use ruff_python_ast::ExprList;
29+
use ruff_python_ast::ExprName;
2930
use ruff_python_ast::ExprNumberLiteral;
3031
use ruff_python_ast::ExprSet;
3132
use ruff_python_ast::ExprSlice;
33+
use ruff_python_ast::ExprStringLiteral;
3234
use ruff_python_ast::ExprSubscript;
3335
use ruff_python_ast::Stmt;
3436
use ruff_python_ast::StmtIf;
@@ -267,6 +269,14 @@ impl SysInfo {
267269
pub fn type_checking(&self) -> bool {
268270
self.0.key().type_checking
269271
}
272+
273+
pub fn with_platform(&self, platform: PythonPlatform) -> Self {
274+
if self.type_checking() {
275+
SysInfo::new(self.version(), platform)
276+
} else {
277+
SysInfo::new_without_type_checking(self.version(), platform)
278+
}
279+
}
270280
}
271281

272282
impl<'de> Deserialize<'de> for SysInfo {
@@ -697,6 +707,83 @@ impl SysInfo {
697707
}
698708
}
699709

710+
pub fn module_platform_guard(body: &[Stmt]) -> Option<PythonPlatform> {
711+
body.iter().find_map(platform_guard_from_stmt)
712+
}
713+
714+
fn platform_guard_from_stmt(stmt: &Stmt) -> Option<PythonPlatform> {
715+
match stmt {
716+
Stmt::Assert(x) => platform_guard_from_assert(&x.test),
717+
Stmt::If(x) if x.elif_else_clauses.is_empty() && body_is_unconditional_raise(&x.body) => {
718+
platform_guard_from_not_eq(&x.test)
719+
}
720+
_ => None,
721+
}
722+
}
723+
724+
fn body_is_unconditional_raise(body: &[Stmt]) -> bool {
725+
matches!(body, [Stmt::Raise(_)])
726+
}
727+
728+
fn platform_guard_from_assert(test: &Expr) -> Option<PythonPlatform> {
729+
let (op, platform) = platform_compare(test)?;
730+
match op {
731+
CmpOp::Eq => Some(PythonPlatform::new(&platform)),
732+
_ => None,
733+
}
734+
}
735+
736+
fn platform_guard_from_not_eq(test: &Expr) -> Option<PythonPlatform> {
737+
let (op, platform) = platform_compare(test)?;
738+
match op {
739+
CmpOp::NotEq => Some(PythonPlatform::new(&platform)),
740+
_ => None,
741+
}
742+
}
743+
744+
fn platform_compare(test: &Expr) -> Option<(CmpOp, String)> {
745+
let Expr::Compare(compare) = test else {
746+
return None;
747+
};
748+
if compare.ops.len() != 1 || compare.comparators.len() != 1 {
749+
return None;
750+
}
751+
let op = compare.ops[0];
752+
if !matches!(op, CmpOp::Eq | CmpOp::NotEq) {
753+
return None;
754+
}
755+
let left = &compare.left;
756+
let right = &compare.comparators[0];
757+
let platform = extract_platform_literal(left, right)?;
758+
Some((op, platform))
759+
}
760+
761+
fn extract_platform_literal(left: &Expr, right: &Expr) -> Option<String> {
762+
if is_sys_platform(left) {
763+
string_literal(right)
764+
} else if is_sys_platform(right) {
765+
string_literal(left)
766+
} else {
767+
None
768+
}
769+
}
770+
771+
fn is_sys_platform(expr: &Expr) -> bool {
772+
matches!(
773+
expr,
774+
Expr::Attribute(ExprAttribute { value, attr, .. })
775+
if matches!(&**value, Expr::Name(ExprName { id, .. }) if id == "sys")
776+
&& attr.as_str() == "platform"
777+
)
778+
}
779+
780+
fn string_literal(expr: &Expr) -> Option<String> {
781+
match expr {
782+
Expr::StringLiteral(ExprStringLiteral { value, .. }) => Some(value.to_str().to_owned()),
783+
_ => None,
784+
}
785+
}
786+
700787
#[cfg(test)]
701788
mod tests {
702789
use super::*;

pyrefly/lib/binding/bindings.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ use pyrefly_python::module_name::ModuleName;
2020
use pyrefly_python::nesting_context::NestingContext;
2121
use pyrefly_python::short_identifier::ShortIdentifier;
2222
use pyrefly_python::sys_info::SysInfo;
23+
use pyrefly_python::sys_info::module_platform_guard;
2324
use pyrefly_types::callable::FuncDefIndex;
2425
use pyrefly_types::type_alias::TypeAliasIndex;
2526
use pyrefly_types::type_info::JoinStyle;
@@ -473,6 +474,9 @@ impl Bindings {
473474
enable_trace: bool,
474475
untyped_def_behavior: UntypedDefBehavior,
475476
) -> Self {
477+
let override_sys_info =
478+
module_platform_guard(&x.body).map(|platform| sys_info.with_platform(platform));
479+
let sys_info = override_sys_info.as_ref().unwrap_or(sys_info);
476480
let mut builder = BindingsBuilder {
477481
module_info: module_info.dupe(),
478482
lookup,

pyrefly/lib/state/state.rs

Lines changed: 32 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ use pyrefly_python::module_name::ModuleName;
3434
use pyrefly_python::module_path::ModulePath;
3535
use pyrefly_python::module_path::ModulePathDetails;
3636
use pyrefly_python::sys_info::SysInfo;
37+
use pyrefly_python::sys_info::module_platform_guard;
3738
use pyrefly_types::type_alias::TypeAliasIndex;
3839
use pyrefly_util::arc_id::ArcId;
3940
use pyrefly_util::events::CategorizedEvents;
@@ -389,6 +390,14 @@ struct ModuleDataMut {
389390
rdeps: Mutex<HashSet<Handle>>,
390391
}
391392

393+
fn module_sys_info_override(
394+
sys_info: &SysInfo,
395+
ast: Option<&ruff_python_ast::ModModule>,
396+
) -> Option<SysInfo> {
397+
let ast = ast?;
398+
module_platform_guard(&ast.body).map(|platform| sys_info.with_platform(platform))
399+
}
400+
392401
impl ModuleData {
393402
/// Make a copy of the data that can be mutated.
394403
fn clone_for_mutation(&self) -> ModuleDataMut {
@@ -1014,11 +1023,18 @@ impl<'a> Transaction<'a> {
10141023
let require = guard.require();
10151024
let stdlib = self.get_stdlib(&module_data.handle);
10161025
let config = module_data.config.read();
1026+
let sys_info_override = module_sys_info_override(
1027+
module_data.handle.sys_info(),
1028+
module_data.state.get_ast().as_deref(),
1029+
);
1030+
let sys_info = sys_info_override
1031+
.as_ref()
1032+
.unwrap_or(module_data.handle.sys_info());
10171033
let ctx = Context {
10181034
require,
10191035
module: module_data.handle.module(),
10201036
path: module_data.handle.path(),
1021-
sys_info: module_data.handle.sys_info(),
1037+
sys_info,
10221038
memory: &self.memory_lookup(),
10231039
uniques: &self.data.state.uniques,
10241040
stdlib: &stdlib,
@@ -1402,13 +1418,17 @@ impl<'a> Transaction<'a> {
14021418
.dupe()
14031419
}
14041420

1405-
pub fn get_stdlib(&self, handle: &Handle) -> Arc<Stdlib> {
1421+
pub fn get_stdlib_for_sys_info(&self, sys_info: &SysInfo) -> Arc<Stdlib> {
14061422
if self.data.stdlib.len() == 1 {
14071423
// Since we know our one must exist, we can shortcut
14081424
return self.data.stdlib.first().unwrap().1.dupe();
14091425
}
14101426

1411-
self.data.stdlib.get(handle.sys_info()).unwrap().dupe()
1427+
self.data.stdlib.get(sys_info).unwrap().dupe()
1428+
}
1429+
1430+
pub fn get_stdlib(&self, handle: &Handle) -> Arc<Stdlib> {
1431+
self.get_stdlib_for_sys_info(handle.sys_info())
14121432
}
14131433

14141434
fn compute_stdlib(&mut self, sys_infos: SmallSet<SysInfo>) {
@@ -1983,14 +2003,17 @@ impl<'a> TransactionHandle<'a> {
19832003
path: Option<&ModulePath>,
19842004
dep: ModuleDep,
19852005
) -> FindingOrError<ArcId<ModuleDataMut>> {
2006+
let sys_info_override = module_sys_info_override(
2007+
self.module_data.handle.sys_info(),
2008+
self.module_data.state.get_ast().as_deref(),
2009+
);
2010+
let sys_info = sys_info_override
2011+
.as_ref()
2012+
.unwrap_or(self.module_data.handle.sys_info());
19862013
let handle = match path {
19872014
Some(path) => {
19882015
// Explicit path — already resolved. Bypass imports entirely.
1989-
FindingOrError::new_finding(Handle::new(
1990-
module,
1991-
path.dupe(),
1992-
self.module_data.handle.sys_info().dupe(),
1993-
))
2016+
FindingOrError::new_finding(Handle::new(module, path.dupe(), sys_info.dupe()))
19942017
}
19952018
None => {
19962019
// No path — needs find_import. Check imports cache first.
@@ -2010,13 +2033,7 @@ impl<'a> TransactionHandle<'a> {
20102033
finding
20112034
}
20122035
};
2013-
path.map(|path| {
2014-
Handle::new(
2015-
module,
2016-
path.dupe(),
2017-
self.module_data.handle.sys_info().dupe(),
2018-
)
2019-
})
2036+
path.map(|path| Handle::new(module, path.dupe(), sys_info.dupe()))
20202037
}
20212038
};
20222039

pyrefly/lib/test/sys_info.rs

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -204,6 +204,22 @@ assert_type(Z(), str)
204204
"#,
205205
);
206206

207+
testcase!(
208+
test_module_platform_guard_imports,
209+
TestEnv::one(
210+
"winonly",
211+
r#"
212+
import sys
213+
assert sys.platform == "win32"
214+
import winreg
215+
x = winreg.OpenKey
216+
"#,
217+
),
218+
r#"
219+
import winonly
220+
"#,
221+
);
222+
207223
testcase!(
208224
test_os_name,
209225
r#"

0 commit comments

Comments
 (0)