Skip to content

Commit f88c53f

Browse files
impl
1 parent 10192a8 commit f88c53f

File tree

3 files changed

+236
-0
lines changed

3 files changed

+236
-0
lines changed

pyrefly/lib/alt/class/django.rs

Lines changed: 229 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,9 @@ use crate::alt::class::class_field::ClassField;
3131
use crate::alt::class::enums::VALUE_PROP;
3232
use crate::alt::types::class_metadata::ClassSynthesizedField;
3333
use crate::alt::types::class_metadata::ClassSynthesizedFields;
34+
use crate::binding::binding::ClassFieldDefinition;
35+
use crate::binding::binding::ExprOrBinding;
36+
use crate::binding::binding::KeyClassField;
3437
use crate::binding::binding::KeyExport;
3538
use crate::types::simplify::unions;
3639

@@ -48,6 +51,16 @@ const NULL: Name = Name::new_static("null");
4851
const MANY_TO_MANY_FIELD: Name = Name::new_static("ManyToManyField");
4952
const MODEL: Name = Name::new_static("Model");
5053
const MANYRELATEDMANAGER: Name = Name::new_static("ManyRelatedManager");
54+
const ONE_TO_ONE_FIELD: Name = Name::new_static("OneToOneField");
55+
const RELATED_NAME: Name = Name::new_static("related_name");
56+
const RELATED_MANAGER: Name = Name::new_static("RelatedManager");
57+
58+
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
59+
enum DjangoRelationKind {
60+
ForeignKey,
61+
OneToOne,
62+
ManyToMany,
63+
}
5164

5265
impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> {
5366
pub fn get_django_field_type(
@@ -184,6 +197,36 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> {
184197
Some(manager_type)
185198
}
186199

200+
// Get RelatedManager class from django stubs
201+
fn get_related_manager_type(&self, target_model_type: Type) -> Option<Type> {
202+
let django_related_module = ModuleName::django_models_fields_related_descriptors();
203+
let django_related_module_exports = self.exports.get(django_related_module).finding()?;
204+
if !django_related_module_exports
205+
.exports(self.exports)
206+
.contains_key(&RELATED_MANAGER)
207+
{
208+
return None;
209+
}
210+
211+
let manager_class_type =
212+
self.get_from_export(django_related_module, None, &KeyExport(RELATED_MANAGER));
213+
214+
let manager_class = match manager_class_type.as_ref() {
215+
Type::ClassDef(cls) => cls,
216+
_ => return None,
217+
};
218+
219+
let targs_vec = vec![target_model_type];
220+
let manager_type = self.specialize(
221+
manager_class,
222+
targs_vec,
223+
TextRange::default(),
224+
&self.error_swallower(),
225+
);
226+
227+
Some(manager_type)
228+
}
229+
187230
fn resolve_target(&self, to_expr: &Expr, class: &Class) -> Option<Type> {
188231
match to_expr {
189232
// Use expr_infer to resolve the name in the current scope
@@ -242,6 +285,13 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> {
242285
)
243286
}
244287

288+
pub fn is_one_to_one_field(&self, field: &Class) -> bool {
289+
field.has_toplevel_qname(
290+
ModuleName::django_models_fields_related().as_str(),
291+
ONE_TO_ONE_FIELD.as_str(),
292+
)
293+
}
294+
245295
pub fn get_django_enum_synthesized_fields(
246296
&self,
247297
cls: &Class,
@@ -481,6 +531,185 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> {
481531
}
482532
}
483533

534+
if let Some(reverse_fields) = self.get_django_reverse_relationship_synthesized_fields(cls) {
535+
for (name, field) in reverse_fields.fields() {
536+
fields.insert(name.clone(), field.clone());
537+
}
538+
}
539+
484540
Some(ClassSynthesizedFields::new(fields))
485541
}
542+
543+
fn get_django_reverse_relationship_synthesized_fields(
544+
&self,
545+
cls: &Class,
546+
) -> Option<ClassSynthesizedFields> {
547+
let mut fields = SmallMap::new();
548+
549+
for field_idx in self.bindings().keys::<KeyClassField>() {
550+
let binding = self.bindings().get(field_idx);
551+
let Some(source_class) = &self.get_idx(binding.class_idx).0 else {
552+
continue;
553+
};
554+
if !self.get_metadata_for_class(source_class).is_django_model() {
555+
continue;
556+
}
557+
558+
let expr = match &binding.definition {
559+
ClassFieldDefinition::AssignedInBody {
560+
value: ExprOrBinding::Expr(expr),
561+
..
562+
} => expr,
563+
_ => continue,
564+
};
565+
let call_expr = match expr.as_call_expr() {
566+
Some(call_expr) => call_expr,
567+
None => continue,
568+
};
569+
570+
let relation_kind = match self.django_relation_kind(expr) {
571+
Some(kind) => kind,
572+
None => continue,
573+
};
574+
575+
let Some(to_expr) = call_expr.arguments.args.first() else {
576+
continue;
577+
};
578+
let Some(target_type) = self.resolve_target(to_expr, source_class) else {
579+
continue;
580+
};
581+
let target_class = match &target_type {
582+
Type::ClassType(cls_type) => cls_type.class_object(),
583+
Type::ClassDef(class_def) => class_def,
584+
_ => continue,
585+
};
586+
if target_class != cls {
587+
continue;
588+
}
589+
590+
let related_name =
591+
match self.django_related_name(call_expr, source_class, relation_kind) {
592+
Some(name) => name,
593+
None => continue,
594+
};
595+
let related_type = match self.django_reverse_field_type(relation_kind, source_class) {
596+
Some(ty) => ty,
597+
None => continue,
598+
};
599+
600+
fields.insert(related_name, ClassSynthesizedField::new(related_type));
601+
}
602+
603+
if fields.is_empty() {
604+
None
605+
} else {
606+
Some(ClassSynthesizedFields::new(fields))
607+
}
608+
}
609+
610+
fn django_relation_kind(&self, expr: &Expr) -> Option<DjangoRelationKind> {
611+
let ty = self.expr_infer(expr, &self.error_swallower());
612+
let field_class = match &ty {
613+
Type::ClassType(cls) => cls.class_object(),
614+
Type::ClassDef(cls) => cls,
615+
_ => return None,
616+
};
617+
618+
if self.is_one_to_one_field(field_class) {
619+
Some(DjangoRelationKind::OneToOne)
620+
} else if self.is_many_to_many_field(field_class) {
621+
Some(DjangoRelationKind::ManyToMany)
622+
} else if self.is_foreign_key_field(field_class) {
623+
Some(DjangoRelationKind::ForeignKey)
624+
} else {
625+
None
626+
}
627+
}
628+
629+
fn django_reverse_field_type(
630+
&self,
631+
relation_kind: DjangoRelationKind,
632+
source_class: &Class,
633+
) -> Option<Type> {
634+
let source_type = self.instantiate(source_class);
635+
match relation_kind {
636+
DjangoRelationKind::ForeignKey => self.get_related_manager_type(source_type),
637+
DjangoRelationKind::ManyToMany => self.get_manager_type(source_type),
638+
DjangoRelationKind::OneToOne => Some(source_type),
639+
}
640+
}
641+
642+
fn django_related_name(
643+
&self,
644+
call_expr: &ExprCall,
645+
source_class: &Class,
646+
relation_kind: DjangoRelationKind,
647+
) -> Option<Name> {
648+
let mut related_name_expr = None;
649+
for keyword in &call_expr.arguments.keywords {
650+
if keyword
651+
.arg
652+
.as_ref()
653+
.is_some_and(|name| name.as_str() == RELATED_NAME.as_str())
654+
{
655+
related_name_expr = Some(&keyword.value);
656+
break;
657+
}
658+
}
659+
660+
match related_name_expr {
661+
None => Some(self.django_default_related_name(source_class, relation_kind)),
662+
Some(Expr::NoneLiteral(_)) => {
663+
Some(self.django_default_related_name(source_class, relation_kind))
664+
}
665+
Some(Expr::StringLiteral(lit)) => {
666+
self.format_related_name(lit.value.to_str(), source_class)
667+
}
668+
_ => None,
669+
}
670+
}
671+
672+
fn django_default_related_name(
673+
&self,
674+
source_class: &Class,
675+
relation_kind: DjangoRelationKind,
676+
) -> Name {
677+
let mut name = source_class.name().as_str().to_ascii_lowercase();
678+
if matches!(
679+
relation_kind,
680+
DjangoRelationKind::ForeignKey | DjangoRelationKind::ManyToMany
681+
) {
682+
name.push_str("_set");
683+
}
684+
Name::new(name)
685+
}
686+
687+
fn format_related_name(&self, raw: &str, source_class: &Class) -> Option<Name> {
688+
let trimmed = raw.trim();
689+
if trimmed.is_empty() || trimmed.ends_with('+') {
690+
return None;
691+
}
692+
693+
let class_name = source_class.name().as_str().to_ascii_lowercase();
694+
let module_name = source_class.module_name();
695+
let module_name = module_name.as_str();
696+
let mut module_parts = module_name.rsplit('.');
697+
let last = module_parts.next().unwrap_or(module_name);
698+
let app_label = if last == "models" {
699+
module_parts.next().unwrap_or(last)
700+
} else {
701+
last
702+
}
703+
.to_ascii_lowercase();
704+
705+
let substituted = trimmed
706+
.replace("%(class)s", &class_name)
707+
.replace("%(app_label)s", &app_label);
708+
709+
if substituted.is_empty() || substituted.ends_with('+') || substituted.contains('%') {
710+
return None;
711+
}
712+
713+
Some(Name::new(substituted))
714+
}
486715
}

pyrefly/lib/test/django/foreign_key.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ django_testcase!(
1616
from typing import assert_type
1717
1818
from django.db import models
19+
from django.db.models.fields.related_descriptors import RelatedManager
1920
2021
class Reporter(models.Model):
2122
full_name = models.CharField(max_length=70)
@@ -28,6 +29,9 @@ assert_type(article.reporter, Reporter)
2829
assert_type(article.reporter.full_name, str)
2930
assert_type(article.reporter_id, int)
3031
32+
reporter = Reporter()
33+
assert_type(reporter.article_set, RelatedManager[Article])
34+
3135
class B(Article):
3236
pass
3337

pyrefly/lib/test/django/many_to_many.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,9 @@ assert_type(book.authors.all(), QuerySet[Author, Author])
3030
assert_type(book.authors.filter(name="Bob"), QuerySet[Author, Author])
3131
assert_type(book.authors.create(name="Alice"), Author)
3232
33+
author = Author()
34+
assert_type(author.books, ManyRelatedManager[Book, models.Model])
35+
3336
book.authors.add("wrong type") # E: Argument `Literal['wrong type']` is not assignable to parameter `*objs` with type `Author | int`
3437
"#,
3538
);

0 commit comments

Comments
 (0)