Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
91 changes: 58 additions & 33 deletions crates/pyrefly_types/src/display.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,10 @@

//! Display a type. The complexity comes from if we have two classes with the same name,
//! we want to display disambiguating information (e.g. module name or location).
use std::borrow::Cow;
use std::fmt;
use std::fmt::Display;

use pyrefly_python::module::TextRangeWithModule;
use pyrefly_python::module_name::ModuleName;
use pyrefly_python::qname::QName;
use pyrefly_util::display::Fmt;
Expand All @@ -29,6 +29,7 @@ use crate::literal::Lit;
use crate::tuple::Tuple;
use crate::type_output::DisplayOutput;
use crate::type_output::OutputWithLocations;
use crate::type_output::TypeLabelPart;
use crate::type_output::TypeOutput;
use crate::types::AnyStyle;
use crate::types::BoundMethod;
Expand Down Expand Up @@ -237,11 +238,12 @@ impl<'a> TypeDisplayContext<'a> {
name: &str,
output: &mut impl TypeOutput,
) -> fmt::Result {
if self.always_display_module_name {
output.write_str(&format!("{}.{}", module, name))
} else {
output.write_str(name)
}
let module_name = ModuleName::from_str(module);
output.write_symbol(
module_name,
Cow::Borrowed(name),
self.always_display_module_name,
)
}

/// Helper function to format a sequence of types with a separator.
Expand Down Expand Up @@ -302,7 +304,8 @@ impl<'a> TypeDisplayContext<'a> {
output.write_targs(class_type.targs())
}
Type::TypedDict(typed_dict) => {
output.write_str("TypedDict[")?;
self.maybe_fmt_with_module("typing", "TypedDict", output)?;
output.write_str("[")?;
output.write_qname(typed_dict.qname())?;
output.write_targs(typed_dict.targs())?;
output.write_str("]")
Expand Down Expand Up @@ -677,7 +680,10 @@ impl<'a> TypeDisplayContext<'a> {
output.write_str(&format!("{q}"))?;
output.write_str("]")
}
Type::SpecialForm(x) => output.write_str(&format!("{x}")),
Type::SpecialForm(x) => {
let name = x.to_string();
self.maybe_fmt_with_module("typing", &name, output)
}
Type::Ellipsis => output.write_str("Ellipsis"),
Type::Any(style) => match style {
AnyStyle::Explicit => self.maybe_fmt_with_module("typing", "Any", output),
Expand Down Expand Up @@ -746,7 +752,7 @@ impl Type {
c.display(self).to_string()
}

pub fn get_types_with_locations(&self) -> Vec<(String, Option<TextRangeWithModule>)> {
pub fn get_types_with_locations(&self) -> Vec<TypeLabelPart> {
let ctx = TypeDisplayContext::new(&[self]);
let mut output = OutputWithLocations::new(&ctx);
ctx.fmt_helper_generic(self, false, &mut output).unwrap();
Expand Down Expand Up @@ -1616,32 +1622,30 @@ def overloaded_func[T](
}

// Helper functions for testing get_types_with_location
fn get_parts(t: &Type) -> Vec<(String, Option<TextRangeWithModule>)> {
fn get_parts(t: &Type) -> Vec<TypeLabelPart> {
let ctx = TypeDisplayContext::new(&[t]);
let output = ctx.get_types_with_location(t, false);
output.parts().to_vec()
}

fn parts_to_string(parts: &[(String, Option<TextRangeWithModule>)]) -> String {
parts.iter().map(|(s, _)| s.as_str()).collect::<String>()
fn parts_to_string(parts: &[TypeLabelPart]) -> String {
parts
.iter()
.map(|part| part.text.as_str())
.collect::<String>()
}

fn assert_part_has_location(
parts: &[(String, Option<TextRangeWithModule>)],
name: &str,
module: &str,
position: u32,
) {
let part = parts.iter().find(|(s, _)| s == name);
fn assert_part_has_location(parts: &[TypeLabelPart], name: &str, module: &str, position: u32) {
let part = parts.iter().find(|part| part.text == name);
assert!(part.is_some(), "Should have {} in parts", name);
let (_, location) = part.unwrap();
assert!(location.is_some(), "{} should have location", name);
let loc = location.as_ref().unwrap();
let loc = part.unwrap().location.as_ref();
assert!(loc.is_some(), "{} should have location", name);
let loc = loc.unwrap();
assert_eq!(loc.module.name().as_str(), module);
assert_eq!(loc.range.start().to_u32(), position);
}

fn assert_output_contains(parts: &[(String, Option<TextRangeWithModule>)], needle: &str) {
fn assert_output_contains(parts: &[TypeLabelPart], needle: &str) {
let full_str = parts_to_string(parts);
assert!(
full_str.contains(needle),
Expand Down Expand Up @@ -1670,9 +1674,12 @@ def overloaded_func[T](
let t = Type::ClassType(ClassType::new(foo, TArgs::new(tparams, vec![inner_type])));
let parts = get_parts(&t);

assert_eq!(parts[0].0, "Foo");
assert_eq!(parts[0].text, "Foo");
assert_part_has_location(&parts, "Foo", "test.module", 10);
assert!(parts.iter().any(|(s, _)| s == "Bar"), "Should have Bar");
assert!(
parts.iter().any(|part| part.text == "Bar"),
"Should have Bar"
);
}

#[test]
Expand All @@ -1681,8 +1688,11 @@ def overloaded_func[T](
let t = tvar.to_type();
let parts = get_parts(&t);

assert_eq!(parts[0].0, "TypeVar[");
assert!(parts[0].1.is_none(), "TypeVar[ should not have location");
assert_eq!(parts[0].text, "TypeVar[");
assert!(
parts[0].location.is_none(),
"TypeVar[ should not have location"
);
assert_part_has_location(&parts, "T", "test.module", 15);
}

Expand All @@ -1698,8 +1708,14 @@ def overloaded_func[T](
let parts1 = ctx.get_types_with_location(&t1, false).parts().to_vec();
let parts2 = ctx.get_types_with_location(&t2, false).parts().to_vec();

let loc1 = parts1.iter().find_map(|(_, loc)| loc.as_ref()).unwrap();
let loc2 = parts2.iter().find_map(|(_, loc)| loc.as_ref()).unwrap();
let loc1 = parts1
.iter()
.find_map(|part| part.location.as_ref())
.unwrap();
let loc2 = parts2
.iter()
.find_map(|part| part.location.as_ref())
.unwrap();
assert_ne!(
loc1.range.start().to_u32(),
loc2.range.start().to_u32(),
Expand All @@ -1714,6 +1730,11 @@ def overloaded_func[T](

assert_output_contains(&parts, "Literal");
assert_output_contains(&parts, "True");
let literal_part = parts.iter().find(|part| part.text == "Literal");
assert!(literal_part.is_some());
let symbol = literal_part.unwrap().symbol.as_ref().unwrap();
assert_eq!(symbol.module.as_str(), "typing");
assert_eq!(symbol.name, "Literal");
}

#[test]
Expand All @@ -1738,8 +1759,8 @@ def overloaded_func[T](
let parts = get_parts(&t);

assert_eq!(parts.len(), 1);
assert_eq!(parts[0].0, "None");
assert!(parts[0].1.is_none(), "None should not have location");
assert_eq!(parts[0].text, "None");
assert!(parts[0].location.is_none(), "None should not have location");
}

#[test]
Expand All @@ -1765,7 +1786,11 @@ def overloaded_func[T](
for param in &["T", "U", "Ts"] {
assert_output_contains(&parts, param);
}
assert!(parts.iter().any(|(s, loc)| s == "[" && loc.is_none()));
assert!(
parts
.iter()
.any(|part| part.text == "[" && part.location.is_none())
);
assert!(parts_to_string(&parts).starts_with('['));
assert_output_contains(&parts, "](");
}
Expand Down Expand Up @@ -1796,7 +1821,7 @@ def overloaded_func[T](
for expected in &["Literal", "Color", "RED"] {
assert_output_contains(&parts, expected);
}
assert!(parts.iter().any(|(_, loc)| loc.is_some()));
assert!(parts.iter().any(|part| part.location.is_some()));
}

#[test]
Expand Down
Loading