Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
329 changes: 328 additions & 1 deletion pyrefly/lib/commands/report.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,16 @@
*/

use std::collections::HashMap;
use std::sync::Arc;

use clap::Parser;
use dupe::Dupe;
use pyrefly_config::args::ConfigOverrideArgs;
use pyrefly_config::finder::ConfigFinder;
use pyrefly_python::module::Module;
use pyrefly_python::module_name::ModuleName;
use pyrefly_python::short_identifier::ShortIdentifier;
use pyrefly_types::class::ClassDefIndex;
use pyrefly_util::forgetter::Forgetter;
use pyrefly_util::includes::Includes;
use regex::Regex;
Expand All @@ -21,9 +24,13 @@ use ruff_text_size::Ranged;
use ruff_text_size::TextRange;
use serde::Serialize;

use crate::alt::answers::Answers;
use crate::alt::types::class_metadata::ClassMro;
use crate::binding::binding::Binding;
use crate::binding::binding::BindingClass;
use crate::binding::binding::Key;
use crate::binding::binding::KeyClass;
use crate::binding::binding::KeyClassMro;
use crate::binding::binding::ReturnTypeKind;
use crate::binding::bindings::Bindings;
use crate::commands::check::Handles;
Expand Down Expand Up @@ -71,11 +78,27 @@ struct Function {
location: Location,
}

/// An incomplete attribute within a class (method with missing annotations)
#[derive(Debug, Serialize)]
struct IncompleteAttribute {
name: String,
declared_in: String,
}

/// Information about a class with incomplete annotations
#[derive(Debug, Serialize)]
struct ReportClass {
name: String,
incomplete_attributes: Vec<IncompleteAttribute>,
location: Location,
}

/// File report
#[derive(Debug, Serialize)]
struct FileReport {
line_count: usize,
functions: Vec<Function>,
classes: Vec<ReportClass>,
suppressions: Vec<Suppression>,
}

Expand Down Expand Up @@ -268,6 +291,158 @@ impl ReportArgs {
functions
}

/// Check if a function is completely annotated (has return annotation and all params annotated except self/cls)
fn is_function_completely_annotated(
bindings: &Bindings,
func_def: &ruff_python_ast::StmtFunctionDef,
) -> bool {
// Check return annotation
let return_key = Key::ReturnType(ShortIdentifier::new(&func_def.name));
let return_idx = bindings.key_to_idx(&return_key);
let has_return_annotation = if let Binding::ReturnType(ret) = bindings.get(return_idx) {
match &ret.kind {
ReturnTypeKind::ShouldValidateAnnotation { .. }
| ReturnTypeKind::ShouldTrustAnnotation { .. } => true,
_ => false,
}
} else {
false
};

if !has_return_annotation {
return false;
}

// Check all parameters (except self/cls which don't need annotations)
let all_params = Self::extract_parameters(&func_def.parameters);
for param in all_params {
let param_name = param.name.as_str();
// Skip self and cls parameters
if param_name == "self" || param_name == "cls" {
continue;
}
if param.annotation.is_none() {
return false;
}
}

true
}

fn parse_classes(
module: &Module,
bindings: Bindings,
answers: Arc<Answers>,
) -> Vec<ReportClass> {
let mut classes = Vec::new();
let module_prefix = if module.name() != ModuleName::unknown() {
format!("{}.", module.name())
} else {
String::new()
};
for class_idx in bindings.keys::<KeyClass>() {
let binding_class = bindings.get(class_idx);
let cls_binding = match binding_class {
BindingClass::ClassDef(cls) => cls,
BindingClass::FunctionalClassDef(..) => continue,
};
let class_type = match answers.get_idx(class_idx) {
Some(result) => match &result.0 {
Some(cls) => cls.clone(),
None => continue,
},
None => continue,
};
let class_name = {
let parent_path = module.display(&cls_binding.parent).to_string();
if parent_path.is_empty() {
format!("{}{}", module_prefix, cls_binding.def.name)
} else {
format!("{}{}.{}", module_prefix, parent_path, cls_binding.def.name)
}
};
let mro = answers
.get_idx(bindings.key_to_idx(&KeyClassMro(ClassDefIndex(class_type.index().0))))
.unwrap_or_else(|| Arc::new(ClassMro::Cyclic));
// Check methods defined directly on this class
let mut incomplete_attributes = Vec::new();
for idx in bindings.keys::<Key>() {
if let Key::Definition(_id) = bindings.idx_to_key(idx)
&& let Binding::Function(x, _pred, _class_meta) = bindings.get(idx)
{
let fun = bindings.get(bindings.get(*x).undecorated_idx);
if let Some(func_class_key) = fun.class_key {
if func_class_key != class_idx {
continue;
}
let method_name = fun.def.name.to_string();
if !Self::is_function_completely_annotated(&bindings, &fun.def) {
incomplete_attributes.push(IncompleteAttribute {
name: method_name.clone(),
declared_in: class_name.clone(),
});
}
}
}
}
// Check inherited methods
if let ClassMro::Resolved(ancestors) = &*mro {
for ancestor_class_type in ancestors {
let ancestor_class = ancestor_class_type.class_object();
let ancestor_name = {
let ancestor_module = ancestor_class.module();
let ancestor_module_prefix =
if ancestor_module.name() != ModuleName::unknown() {
format!("{}.", ancestor_module.name())
} else {
String::new()
};
let ancestor_parent_path = ancestor_module
.display(ancestor_class.qname().parent())
.to_string();
if ancestor_parent_path.is_empty() {
format!("{}{}", ancestor_module_prefix, ancestor_class.name())
} else {
format!(
"{}{}.{}",
ancestor_module_prefix,
ancestor_parent_path,
ancestor_class.name()
)
}
};
// Skip methods inherited from builtins
if ancestor_class.module_name().as_str() == "builtins" {
continue;
}
for field_name in ancestor_class.fields() {
let field_name_str = field_name.to_string();
// Skip if we already have this attribute listed (it has been overridden by the current class or another class in the MRO)
if incomplete_attributes
.iter()
.any(|a| a.name == field_name_str)
{
continue;
}
if !ancestor_class.is_field_annotated(field_name) {
incomplete_attributes.push(IncompleteAttribute {
name: field_name_str,
declared_in: ancestor_name.clone(),
});
}
}
}
}
let location = Self::range_to_location(module, cls_binding.def.range);
classes.push(ReportClass {
name: class_name,
incomplete_attributes,
location,
});
}
classes
}

fn run_inner(
files_to_check: Box<dyn Includes>,
config_finder: ConfigFinder,
Expand Down Expand Up @@ -298,16 +473,19 @@ impl ReportArgs {

if let Some(bindings) = transaction.get_bindings(&handle)
&& let Some(module) = transaction.get_module_info(&handle)
&& let Some(answers) = transaction.get_answers(&handle)
{
let line_count = module.lined_buffer().line_index().line_count();
let functions = Self::parse_functions(&module, bindings);
let functions = Self::parse_functions(&module, bindings.dupe());
let classes = Self::parse_classes(&module, bindings.dupe(), answers);
let suppressions = Self::parse_suppressions(&module);

report.insert(
handle.path().as_path().display().to_string(),
FileReport {
line_count,
functions,
classes,
suppressions,
},
);
Expand Down Expand Up @@ -462,4 +640,153 @@ def foo():
assert_eq!(functions.len(), 1);
assert_eq!(functions[0].name, "foo");
}

#[test]
fn test_parse_classes_with_incomplete_methods() {
let code = r#"
class Complete:
def method(self, x: int) -> bool:
return True

class Incomplete:
def method_unannotated(self, x):
pass
def method_partial(self, x: int):
pass
def method_complete(self, x: int) -> bool:
return True
"#;
let (state, handle_fn) = TestEnv::one("test", code)
.with_default_require_level(Require::Everything)
.to_state();
let handle = handle_fn("test");
let transaction = state.transaction();

let module = transaction.get_module_info(&handle).unwrap();
let bindings = transaction.get_bindings(&handle).unwrap();
let answers = transaction.get_answers(&handle).unwrap();

let classes = ReportArgs::parse_classes(&module, bindings, answers);

// Only Incomplete should be reported (Complete has all methods annotated)
assert_eq!(classes.len(), 1);
assert_eq!(classes[0].name, "test.Incomplete");

// Should have method_unannotated and method_partial as incomplete attributes
assert_eq!(classes[0].incomplete_attributes.len(), 2);

let attr_names: Vec<&str> = classes[0]
.incomplete_attributes
.iter()
.map(|a| a.name.as_str())
.collect();
assert!(attr_names.contains(&"method_unannotated"));
assert!(attr_names.contains(&"method_partial"));

// All should be declared in Incomplete
for attr in &classes[0].incomplete_attributes {
assert_eq!(attr.declared_in, "test.Incomplete");
}
}

#[test]
fn test_parse_classes_ignores_object_methods() {
let code = r#"
class MyClass:
def __init__(self):
pass
def __repr__(self):
return "MyClass"
def __eq__(self, other):
return True
def custom_method(self, x):
pass
"#;
let (state, handle_fn) = TestEnv::one("test", code)
.with_default_require_level(Require::Everything)
.to_state();
let handle = handle_fn("test");
let transaction = state.transaction();

let module = transaction.get_module_info(&handle).unwrap();
let bindings = transaction.get_bindings(&handle).unwrap();
let answers = transaction.get_answers(&handle).unwrap();

let classes = ReportArgs::parse_classes(&module, bindings, answers);

// MyClass should be reported because custom_method is incomplete
assert_eq!(classes.len(), 1);
assert_eq!(classes[0].name, "test.MyClass");

// Only custom_method should be reported, not __init__, __repr__, __eq__
assert_eq!(classes[0].incomplete_attributes.len(), 1);
assert_eq!(classes[0].incomplete_attributes[0].name, "custom_method");
}

#[test]
fn test_parse_classes_nested() {
let code = r#"
class Outer:
def method(self, x: int) -> bool:
return True

class Inner:
def inner_method(self, x):
pass
"#;
let (state, handle_fn) = TestEnv::one("test", code)
.with_default_require_level(Require::Everything)
.to_state();
let handle = handle_fn("test");
let transaction = state.transaction();

let module = transaction.get_module_info(&handle).unwrap();
let bindings = transaction.get_bindings(&handle).unwrap();
let answers = transaction.get_answers(&handle).unwrap();

let classes = ReportArgs::parse_classes(&module, bindings, answers);

// Only Inner should be reported (Outer is complete)
assert_eq!(classes.len(), 1);
assert_eq!(classes[0].name, "test.Outer.Inner");
assert_eq!(classes[0].incomplete_attributes.len(), 1);
assert_eq!(classes[0].incomplete_attributes[0].name, "inner_method");
}

#[test]
fn test_parse_classes_inheritance() {
let code = r#"
class Base:
def base_method(self, x):
pass

class Child(Base):
def child_method(self, x: int) -> bool:
return True
"#;
let (state, handle_fn) = TestEnv::one("test", code)
.with_default_require_level(Require::Everything)
.to_state();
let handle = handle_fn("test");
let transaction = state.transaction();

let module = transaction.get_module_info(&handle).unwrap();
let bindings = transaction.get_bindings(&handle).unwrap();
let answers = transaction.get_answers(&handle).unwrap();

let classes = ReportArgs::parse_classes(&module, bindings, answers);

// Both Base and Child should be reported
// Base has base_method incomplete
// Child inherits base_method from Base
assert!(!classes.is_empty());

// Find Base class
let base = classes.iter().find(|c| c.name == "test.Base");
assert!(base.is_some(), "Base class should be reported");
let base = base.unwrap();
assert_eq!(base.incomplete_attributes.len(), 1);
assert_eq!(base.incomplete_attributes[0].name, "base_method");
assert_eq!(base.incomplete_attributes[0].declared_in, "test.Base");
}
}
Loading