Skip to content

Commit 1bd387e

Browse files
impl
1 parent e868606 commit 1bd387e

File tree

3 files changed

+239
-0
lines changed

3 files changed

+239
-0
lines changed

pyrefly/lib/alt/class/django.rs

Lines changed: 232 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,9 @@ use crate::alt::class::class_field::ClassField;
3131
use crate::alt::class::enums::VALUE_PROP;
3232
use crate::alt::types::class_metadata::ClassSynthesizedField;
3333
use crate::alt::types::class_metadata::ClassSynthesizedFields;
34+
use crate::binding::binding::ClassFieldDefinition;
35+
use crate::binding::binding::ExprOrBinding;
36+
use crate::binding::binding::KeyClassField;
3437
use crate::binding::binding::KeyExport;
3538
use crate::types::simplify::unions;
3639

@@ -67,6 +70,19 @@ fn has_keyword_true(call_expr: &ExprCall, name: &Name) -> bool {
6770
.is_some_and(|v| matches!(v, Expr::BooleanLiteral(lit) if lit.value))
6871
}
6972

73+
const ONE_TO_ONE_FIELD: Name = Name::new_static("OneToOneField");
74+
75+
const RELATED_NAME: Name = Name::new_static("related_name");
76+
77+
const RELATED_MANAGER: Name = Name::new_static("RelatedManager");
78+
79+
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
80+
enum DjangoRelationKind {
81+
ForeignKey,
82+
OneToOne,
83+
ManyToMany,
84+
}
85+
7086
impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> {
7187
pub fn get_django_field_type(
7288
&self,
@@ -225,6 +241,36 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> {
225241
Some(manager_type)
226242
}
227243

244+
// Get RelatedManager class from django stubs
245+
fn get_related_manager_type(&self, target_model_type: Type) -> Option<Type> {
246+
let django_related_module = ModuleName::django_models_fields_related_descriptors();
247+
let django_related_module_exports = self.exports.get(django_related_module).finding()?;
248+
if !django_related_module_exports
249+
.exports(self.exports)
250+
.contains_key(&RELATED_MANAGER)
251+
{
252+
return None;
253+
}
254+
255+
let manager_class_type =
256+
self.get_from_export(django_related_module, None, &KeyExport(RELATED_MANAGER));
257+
258+
let manager_class = match manager_class_type.as_ref() {
259+
Type::ClassDef(cls) => cls,
260+
_ => return None,
261+
};
262+
263+
let targs_vec = vec![target_model_type];
264+
let manager_type = self.specialize(
265+
manager_class,
266+
targs_vec,
267+
TextRange::default(),
268+
&self.error_swallower(),
269+
);
270+
271+
Some(manager_type)
272+
}
273+
228274
fn resolve_target(&self, to_expr: &Expr, class: &Class) -> Option<Type> {
229275
match to_expr {
230276
// Use expr_infer to resolve the name in the current scope
@@ -283,6 +329,13 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> {
283329
)
284330
}
285331

332+
pub fn is_one_to_one_field(&self, field: &Class) -> bool {
333+
field.has_toplevel_qname(
334+
ModuleName::django_models_fields_related().as_str(),
335+
ONE_TO_ONE_FIELD.as_str(),
336+
)
337+
}
338+
286339
pub fn get_django_enum_synthesized_fields(
287340
&self,
288341
cls: &Class,
@@ -547,6 +600,185 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> {
547600
}
548601
}
549602

603+
if let Some(reverse_fields) = self.get_django_reverse_relationship_synthesized_fields(cls) {
604+
for (name, field) in reverse_fields.fields() {
605+
fields.insert(name.clone(), field.clone());
606+
}
607+
}
608+
550609
Some(ClassSynthesizedFields::new(fields))
551610
}
611+
612+
fn get_django_reverse_relationship_synthesized_fields(
613+
&self,
614+
cls: &Class,
615+
) -> Option<ClassSynthesizedFields> {
616+
let mut fields = SmallMap::new();
617+
618+
for field_idx in self.bindings().keys::<KeyClassField>() {
619+
let binding = self.bindings().get(field_idx);
620+
let Some(source_class) = &self.get_idx(binding.class_idx).0 else {
621+
continue;
622+
};
623+
if !self.get_metadata_for_class(source_class).is_django_model() {
624+
continue;
625+
}
626+
627+
let expr = match &binding.definition {
628+
ClassFieldDefinition::AssignedInBody {
629+
value: ExprOrBinding::Expr(expr),
630+
..
631+
} => expr,
632+
_ => continue,
633+
};
634+
let call_expr = match expr.as_call_expr() {
635+
Some(call_expr) => call_expr,
636+
None => continue,
637+
};
638+
639+
let relation_kind = match self.django_relation_kind(expr) {
640+
Some(kind) => kind,
641+
None => continue,
642+
};
643+
644+
let Some(to_expr) = call_expr.arguments.args.first() else {
645+
continue;
646+
};
647+
let Some(target_type) = self.resolve_target(to_expr, source_class) else {
648+
continue;
649+
};
650+
let target_class = match &target_type {
651+
Type::ClassType(cls_type) => cls_type.class_object(),
652+
Type::ClassDef(class_def) => class_def,
653+
_ => continue,
654+
};
655+
if target_class != cls {
656+
continue;
657+
}
658+
659+
let related_name =
660+
match self.django_related_name(call_expr, source_class, relation_kind) {
661+
Some(name) => name,
662+
None => continue,
663+
};
664+
let related_type = match self.django_reverse_field_type(relation_kind, source_class) {
665+
Some(ty) => ty,
666+
None => continue,
667+
};
668+
669+
fields.insert(related_name, ClassSynthesizedField::new(related_type));
670+
}
671+
672+
if fields.is_empty() {
673+
None
674+
} else {
675+
Some(ClassSynthesizedFields::new(fields))
676+
}
677+
}
678+
679+
fn django_relation_kind(&self, expr: &Expr) -> Option<DjangoRelationKind> {
680+
let ty = self.expr_infer(expr, &self.error_swallower());
681+
let field_class = match &ty {
682+
Type::ClassType(cls) => cls.class_object(),
683+
Type::ClassDef(cls) => cls,
684+
_ => return None,
685+
};
686+
687+
if self.is_one_to_one_field(field_class) {
688+
Some(DjangoRelationKind::OneToOne)
689+
} else if self.is_many_to_many_field(field_class) {
690+
Some(DjangoRelationKind::ManyToMany)
691+
} else if self.is_foreign_key_field(field_class) {
692+
Some(DjangoRelationKind::ForeignKey)
693+
} else {
694+
None
695+
}
696+
}
697+
698+
fn django_reverse_field_type(
699+
&self,
700+
relation_kind: DjangoRelationKind,
701+
source_class: &Class,
702+
) -> Option<Type> {
703+
let source_type = self.instantiate(source_class);
704+
match relation_kind {
705+
DjangoRelationKind::ForeignKey => self.get_related_manager_type(source_type),
706+
DjangoRelationKind::ManyToMany => self.get_manager_type(source_type),
707+
DjangoRelationKind::OneToOne => Some(source_type),
708+
}
709+
}
710+
711+
fn django_related_name(
712+
&self,
713+
call_expr: &ExprCall,
714+
source_class: &Class,
715+
relation_kind: DjangoRelationKind,
716+
) -> Option<Name> {
717+
let mut related_name_expr = None;
718+
for keyword in &call_expr.arguments.keywords {
719+
if keyword
720+
.arg
721+
.as_ref()
722+
.is_some_and(|name| name.as_str() == RELATED_NAME.as_str())
723+
{
724+
related_name_expr = Some(&keyword.value);
725+
break;
726+
}
727+
}
728+
729+
match related_name_expr {
730+
None => Some(self.django_default_related_name(source_class, relation_kind)),
731+
Some(Expr::NoneLiteral(_)) => {
732+
Some(self.django_default_related_name(source_class, relation_kind))
733+
}
734+
Some(Expr::StringLiteral(lit)) => {
735+
self.format_related_name(lit.value.to_str(), source_class)
736+
}
737+
_ => None,
738+
}
739+
}
740+
741+
fn django_default_related_name(
742+
&self,
743+
source_class: &Class,
744+
relation_kind: DjangoRelationKind,
745+
) -> Name {
746+
let mut name = source_class.name().as_str().to_ascii_lowercase();
747+
if matches!(
748+
relation_kind,
749+
DjangoRelationKind::ForeignKey | DjangoRelationKind::ManyToMany
750+
) {
751+
name.push_str("_set");
752+
}
753+
Name::new(name)
754+
}
755+
756+
fn format_related_name(&self, raw: &str, source_class: &Class) -> Option<Name> {
757+
let trimmed = raw.trim();
758+
if trimmed.is_empty() || trimmed.ends_with('+') {
759+
return None;
760+
}
761+
762+
let class_name = source_class.name().as_str().to_ascii_lowercase();
763+
let module_name = source_class.module_name();
764+
let module_name = module_name.as_str();
765+
let mut module_parts = module_name.rsplit('.');
766+
let last = module_parts.next().unwrap_or(module_name);
767+
let app_label = if last == "models" {
768+
module_parts.next().unwrap_or(last)
769+
} else {
770+
last
771+
}
772+
.to_ascii_lowercase();
773+
774+
let substituted = trimmed
775+
.replace("%(class)s", &class_name)
776+
.replace("%(app_label)s", &app_label);
777+
778+
if substituted.is_empty() || substituted.ends_with('+') || substituted.contains('%') {
779+
return None;
780+
}
781+
782+
Some(Name::new(substituted))
783+
}
552784
}

pyrefly/lib/test/django/foreign_key.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ django_testcase!(
1616
from typing import assert_type
1717
1818
from django.db import models
19+
from django.db.models.fields.related_descriptors import RelatedManager
1920
2021
class Reporter(models.Model):
2122
full_name = models.CharField(max_length=70)
@@ -28,6 +29,9 @@ assert_type(article.reporter, Reporter)
2829
assert_type(article.reporter.full_name, str)
2930
assert_type(article.reporter_id, int)
3031
32+
reporter = Reporter()
33+
assert_type(reporter.article_set, RelatedManager[Article])
34+
3135
class B(Article):
3236
pass
3337

pyrefly/lib/test/django/many_to_many.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,9 @@ assert_type(book.authors.all(), QuerySet[Author, Author])
3030
assert_type(book.authors.filter(name="Bob"), QuerySet[Author, Author])
3131
assert_type(book.authors.create(name="Alice"), Author)
3232
33+
author = Author()
34+
assert_type(author.books, ManyRelatedManager[Book, models.Model])
35+
3336
book.authors.add("wrong type") # E: Argument `Literal['wrong type']` is not assignable to parameter `*objs` with type `Author | int`
3437
"#,
3538
);

0 commit comments

Comments
 (0)