Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,9 @@ def get_class_reference_for_type_id(
@abstractmethod
def does_circularly_reference_itself(self, type_id: ir_types.TypeId) -> bool: ...

@abstractmethod
def is_in_circular_cluster(self, type_id: ir_types.TypeId) -> bool: ...

@abstractmethod
def get_non_union_self_referencing_dependencies_from_types(
self,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from fern_python.codegen import AST, Filepath
from fern_python.declaration_referencer import AbstractDeclarationReferencer
from fern_python.generators.pydantic_model.custom_config import UnionNamingVersions
from fern_python.generators.pydantic_model.circular_dependency_detector import CircularDependencyDetector
from ordered_set import OrderedSet

import fern.ir.resources as ir_types
Expand Down Expand Up @@ -59,6 +60,11 @@ def __init__(
defaultdict(OrderedSet)
)

self._circular_detector = CircularDependencyDetector(ir)
self._types_in_circular_clusters: Set[ir_types.TypeId] = set()
for cluster in self._circular_detector.get_all_circular_clusters():
self._types_in_circular_clusters.update(cluster)

for id, type in self.ir.types.items():
ordered_reference_types = OrderedSet(list(sorted(type.referenced_types)))
for referenced_id in ordered_reference_types:
Expand Down Expand Up @@ -169,6 +175,9 @@ def get_class_reference_for_type_id(
def does_circularly_reference_itself(self, type_id: ir_types.TypeId) -> bool:
return self.does_type_reference_other_type(type_id, type_id)

def is_in_circular_cluster(self, type_id: ir_types.TypeId) -> bool:
return type_id in self._types_in_circular_clusters

# This map goes from every non union type to a list of referenced types that circularly reference themselves
def get_non_union_self_referencing_dependencies_from_types(
self,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
from __future__ import annotations

from typing import Dict, List, Set, TYPE_CHECKING

if TYPE_CHECKING:
import fern.ir.resources as ir_types


class CircularDependencyDetector:
"""
Detects mutually recursive type clusters that should be consolidated into a single file.

This addresses the issue where deeply mutually recursive types (like AST operators)
cause circular import errors and Pydantic recursion issues when split across files.
"""

def __init__(self, ir: ir_types.IntermediateRepresentation):
self._ir = ir
self._type_clusters: Dict[ir_types.TypeId, Set[ir_types.TypeId]] = {}
self._computed = False

def get_type_cluster(self, type_id: ir_types.TypeId) -> Set[ir_types.TypeId]:
"""
Returns the cluster of mutually recursive types that includes the given type_id.
If the type is not part of a circular dependency cluster, returns a set with just the type_id.
"""
if not self._computed:
self._compute_clusters()
return self._type_clusters.get(type_id, {type_id})

def is_in_circular_cluster(self, type_id: ir_types.TypeId) -> bool:
"""Returns True if the type is part of a mutually recursive cluster with 2+ types."""
cluster = self.get_type_cluster(type_id)
return len(cluster) > 1

def _compute_clusters(self) -> None:
"""
Computes clusters of mutually recursive types using graph analysis.

Algorithm:
1. Build a directed graph of type references
2. Find strongly connected components (SCCs)
3. SCCs with 2+ nodes are circular dependency clusters
"""
graph: Dict[ir_types.TypeId, Set[ir_types.TypeId]] = {}
for type_id, type_decl in self._ir.types.items():
graph[type_id] = set(type_decl.referenced_types)

sccs = self._find_strongly_connected_components(graph)

for scc in sccs:
if len(scc) > 1:
scc_set = set(scc)
for type_id in scc:
self._type_clusters[type_id] = scc_set
else:
self._type_clusters[scc[0]] = {scc[0]}

self._computed = True

def _find_strongly_connected_components(
self, graph: Dict[ir_types.TypeId, Set[ir_types.TypeId]]
) -> List[List[ir_types.TypeId]]:
"""
Tarjan's algorithm for finding strongly connected components.
Returns a list of SCCs, where each SCC is a list of type_ids.
"""
index_counter = [0]
stack: List[ir_types.TypeId] = []
lowlinks: Dict[ir_types.TypeId, int] = {}
index: Dict[ir_types.TypeId, int] = {}
on_stack: Dict[ir_types.TypeId, bool] = {}
sccs: List[List[ir_types.TypeId]] = []

def strongconnect(node: ir_types.TypeId) -> None:
index[node] = index_counter[0]
lowlinks[node] = index_counter[0]
index_counter[0] += 1
stack.append(node)
on_stack[node] = True

successors = graph.get(node, set())
for successor in successors:
if successor not in index:
strongconnect(successor)
lowlinks[node] = min(lowlinks[node], lowlinks[successor])
elif on_stack.get(successor, False):
lowlinks[node] = min(lowlinks[node], index[successor])

if lowlinks[node] == index[node]:
scc: List[ir_types.TypeId] = []
while True:
successor = stack.pop()
on_stack[successor] = False
scc.append(successor)
if successor == node:
break
sccs.append(scc)

for node in graph.keys():
if node not in index:
strongconnect(node)

return sccs

def get_all_circular_clusters(self) -> List[Set[ir_types.TypeId]]:
"""Returns all circular dependency clusters (with 2+ types)."""
if not self._computed:
self._compute_clusters()

seen_clusters: Set[frozenset[ir_types.TypeId]] = set()
circular_clusters: List[Set[ir_types.TypeId]] = []

for type_id, cluster in self._type_clusters.items():
if len(cluster) > 1:
cluster_key = frozenset(cluster)
if cluster_key not in seen_clusters:
seen_clusters.add(cluster_key)
circular_clusters.append(cluster)

return circular_clusters
Original file line number Diff line number Diff line change
Expand Up @@ -292,7 +292,14 @@ def finish(self) -> None:
):
self._pydantic_model.add_partial_class()
self._get_validators_generator().add_validators()
if self._model_contains_forward_refs or self._force_update_forward_refs:

type_id_for_circular_check = self._type_id_for_forward_ref()
is_in_circular_cluster = (
type_id_for_circular_check is not None
and self._context.is_in_circular_cluster(type_id_for_circular_check)
)

if (self._model_contains_forward_refs or self._force_update_forward_refs) and not is_in_circular_cluster:
self._pydantic_model.update_forward_refs()

# Acknowledge forward refs for extended models as well
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Literal, Tuple
from typing import TYPE_CHECKING, Literal, Tuple, Set

if TYPE_CHECKING:
from ..context.pydantic_generator_context import PydanticGeneratorContext
Expand All @@ -12,8 +12,9 @@
TypeDeclarationSnippetGeneratorBuilder,
)
from .type_declaration_referencer import TypeDeclarationReferencer
from .circular_dependency_detector import CircularDependencyDetector
from fern_python.cli.abstract_generator import AbstractGenerator
from fern_python.codegen import Project
from fern_python.codegen import Project, Filepath, AST
from fern_python.generator_exec_wrapper import GeneratorExecWrapper
from fern_python.generators.pydantic_model.model_utilities import can_be_fern_model
from fern_python.snippet import SnippetRegistry, SnippetWriter
Expand Down Expand Up @@ -117,17 +118,37 @@ def generate_types(
snippet_registry: SnippetRegistry,
snippet_writer: SnippetWriter,
) -> None:
for type_to_generate in ir.types.values():
self._generate_type(
project,
type=type_to_generate,
detector = CircularDependencyDetector(ir)
circular_clusters = detector.get_all_circular_clusters()

types_in_clusters: Set[ir_types.TypeId] = set()
for cluster in circular_clusters:
types_in_clusters.update(cluster)

for cluster in circular_clusters:
self._generate_circular_cluster(
project=project,
cluster=cluster,
generator_exec_wrapper=generator_exec_wrapper,
custom_config=custom_config,
context=context,
snippet_registry=snippet_registry,
snippet_writer=snippet_writer,
ir=ir,
)

for type_to_generate in ir.types.values():
if type_to_generate.name.type_id not in types_in_clusters:
self._generate_type(
project,
type=type_to_generate,
generator_exec_wrapper=generator_exec_wrapper,
custom_config=custom_config,
context=context,
snippet_registry=snippet_registry,
snippet_writer=snippet_writer,
)

def _should_generate_typedict(self, context: "PydanticGeneratorContext", type_: ir_types.Type) -> bool:
return context.use_typeddict_requests and can_be_fern_model(type_, context.ir.types)

Expand Down Expand Up @@ -185,6 +206,57 @@ def _generate_type(
)
project.write_source_file(source_file=source_file, filepath=filepath)

def _generate_circular_cluster(
self,
project: Project,
cluster: Set[ir_types.TypeId],
generator_exec_wrapper: GeneratorExecWrapper,
custom_config: PydanticModelCustomConfig,
context: "PydanticGeneratorContext",
snippet_registry: SnippetRegistry,
snippet_writer: SnippetWriter,
ir: ir_types.IntermediateRepresentation,
) -> None:
sorted_cluster = sorted(cluster, key=lambda tid: str(tid))
canonical_type_id = sorted_cluster[0]

base_filepath = context.get_filepath_for_type_id(type_id=canonical_type_id, as_request=False)

canonical_decl = ir.types[canonical_type_id]
consolidated_filename = f"{canonical_decl.name.name.snake_case.safe_name}_all"

consolidated_filepath = Filepath(
directories=base_filepath.directories,
file=Filepath.FilepathPart(module_name=consolidated_filename),
)

source_file = context.source_file_factory.create(
project=project,
filepath=consolidated_filepath,
generator_exec_wrapper=generator_exec_wrapper,
)

for type_id in sorted_cluster:
type_decl = ir.types[type_id]

type_declaration_handler = TypeDeclarationHandler(
declaration=type_decl,
context=context,
custom_config=custom_config,
source_file=source_file,
snippet_writer=snippet_writer,
generate_typeddict_request=False,
)
generated_type = type_declaration_handler.run()

if generated_type.snippet is not None:
snippet_registry.register_snippet(
type_id=type_decl.name.type_id,
expr=generated_type.snippet,
)

project.write_source_file(source_file=source_file, filepath=consolidated_filepath)

def get_sorted_modules(self) -> None:
return None

Expand Down
5 changes: 5 additions & 0 deletions seed/python-sdk/seed.yml
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,11 @@ fixtures:
- customConfig:
use_inheritance_for_extended_models: false
outputFolder: no-inheritance-for-extended-models
deep-circular-references:
- customConfig:
pydantic_config:
version: v2
outputFolder: no-custom-config
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

let's add fixtures for v1 and v1_on_v2 (or whatever that's called)

file-download:
- customConfig:
use_inheritance_for_extended_models: false
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
exports:
types:
- AndOperator
- OrOperator
- EqualsOperator
- GreaterThanOperator
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
name: api
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
types:
# Boolean operators (AND/OR) that can contain comparison operators
AndOperator:
properties:
children: list<AndOperatorChild>

AndOperatorChild:
union:
and_operator: AndOperator
or_operator: OrOperator
eq_operator: EqualsOperator
gt_operator: GreaterThanOperator
boolean_literal: boolean

OrOperator:
properties:
children: list<OrOperatorChild>

OrOperatorChild:
union:
and_operator: AndOperator
or_operator: OrOperator
eq_operator: EqualsOperator
gt_operator: GreaterThanOperator
boolean_literal: boolean

# Comparison operators that can have boolean operators on left/right
EqualsOperator:
properties:
left: EqualsOperatorLeft
right: EqualsOperatorRight

EqualsOperatorLeft:
union:
and_operator: AndOperator
or_operator: OrOperator
eq_operator: EqualsOperator
gt_operator: GreaterThanOperator
number_literal: double
string_literal: string

EqualsOperatorRight:
union:
and_operator: AndOperator
or_operator: OrOperator
eq_operator: EqualsOperator
gt_operator: GreaterThanOperator
number_literal: double
string_literal: string

GreaterThanOperator:
properties:
left: GreaterThanOperatorLeft
right: GreaterThanOperatorRight

GreaterThanOperatorLeft:
union:
and_operator: AndOperator
or_operator: OrOperator
eq_operator: EqualsOperator
gt_operator: GreaterThanOperator
number_literal: double

GreaterThanOperatorRight:
union:
and_operator: AndOperator
or_operator: OrOperator
eq_operator: EqualsOperator
gt_operator: GreaterThanOperator
number_literal: double
Loading
Loading