@@ -69,6 +69,9 @@ def get_class_reference_for_type_id(
@abstractmethod
def does_circularly_reference_itself(self, type_id: ir_types.TypeId) -> bool: ...

@abstractmethod
def is_in_circular_cluster(self, type_id: ir_types.TypeId) -> bool: ...

@abstractmethod
def get_non_union_self_referencing_dependencies_from_types(
self,
@@ -8,6 +8,7 @@
from .type_reference_to_type_hint_converter import TypeReferenceToTypeHintConverter
from fern_python.codegen import AST, Filepath
from fern_python.declaration_referencer import AbstractDeclarationReferencer
from fern_python.generators.pydantic_model.circular_dependency_detector import CircularDependencyDetector
from fern_python.generators.pydantic_model.custom_config import UnionNamingVersions
from ordered_set import OrderedSet

@@ -59,6 +60,11 @@ def __init__(
defaultdict(OrderedSet)
)

self._circular_detector = CircularDependencyDetector(ir)
self._types_in_circular_clusters: Set[ir_types.TypeId] = set()
for cluster in self._circular_detector.get_all_circular_clusters():
self._types_in_circular_clusters.update(cluster)

for id, type in self.ir.types.items():
ordered_reference_types = OrderedSet(list(sorted(type.referenced_types)))
for referenced_id in ordered_reference_types:
@@ -169,6 +175,9 @@ def get_class_reference_for_type_id(
def does_circularly_reference_itself(self, type_id: ir_types.TypeId) -> bool:
return self.does_type_reference_other_type(type_id, type_id)

def is_in_circular_cluster(self, type_id: ir_types.TypeId) -> bool:
return type_id in self._types_in_circular_clusters

# This map goes from every non union type to a list of referenced types that circularly reference themselves
def get_non_union_self_referencing_dependencies_from_types(
self,
@@ -0,0 +1,121 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Dict, List, Set

if TYPE_CHECKING:
import fern.ir.resources as ir_types


class CircularDependencyDetector:
"""
Detects mutually recursive type clusters that should be consolidated into a single file.

This addresses the issue where deeply mutually recursive types (like AST operators)
cause circular import errors and Pydantic recursion issues when split across files.
"""

def __init__(self, ir: ir_types.IntermediateRepresentation):
self._ir = ir
self._type_clusters: Dict[ir_types.TypeId, Set[ir_types.TypeId]] = {}
self._computed = False

def get_type_cluster(self, type_id: ir_types.TypeId) -> Set[ir_types.TypeId]:
"""
Returns the cluster of mutually recursive types that includes the given type_id.
If the type is not part of a circular dependency cluster, returns a set with just the type_id.
"""
if not self._computed:
self._compute_clusters()
return self._type_clusters.get(type_id, {type_id})
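# Illustrative sketch (hypothetical type ids): if "A" and "B" reference each
# other, get_type_cluster("A") returns {"A", "B"}; a type that takes part in
# no cycle maps to the singleton {type_id}.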

def is_in_circular_cluster(self, type_id: ir_types.TypeId) -> bool:
"""Returns True if the type is part of a mutually recursive cluster with 2+ types."""
cluster = self.get_type_cluster(type_id)
return len(cluster) > 1
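# Note: a type that only references itself forms a single-node SCC, so it does
# not count as a circular cluster here; only mutual recursion across two or
# more distinct types does.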

def _compute_clusters(self) -> None:
"""
Computes clusters of mutually recursive types using graph analysis.

Algorithm:
1. Build a directed graph of type references
2. Find strongly connected components (SCCs)
3. SCCs with 2+ nodes are circular dependency clusters
"""
graph: Dict[ir_types.TypeId, Set[ir_types.TypeId]] = {}
for type_id, type_decl in self._ir.types.items():
graph[type_id] = set(type_decl.referenced_types)

sccs = self._find_strongly_connected_components(graph)

for scc in sccs:
if len(scc) > 1:
scc_set = set(scc)
for type_id in scc:
self._type_clusters[type_id] = scc_set
else:
self._type_clusters[scc[0]] = {scc[0]}

self._computed = True
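# Worked example (hypothetical ids): given references A -> B, B -> C, C -> A,
# and D -> A, the only SCC with 2+ nodes is {A, B, C}, so A, B, and C each map
# to that cluster while D maps to the singleton {D}.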

def _find_strongly_connected_components(
self, graph: Dict[ir_types.TypeId, Set[ir_types.TypeId]]
) -> List[List[ir_types.TypeId]]:
"""
Tarjan's algorithm for finding strongly connected components.
Returns a list of SCCs, where each SCC is a list of type_ids.
"""
index_counter = [0]
stack: List[ir_types.TypeId] = []
lowlinks: Dict[ir_types.TypeId, int] = {}
index: Dict[ir_types.TypeId, int] = {}
on_stack: Dict[ir_types.TypeId, bool] = {}
sccs: List[List[ir_types.TypeId]] = []

def strongconnect(node: ir_types.TypeId) -> None:
index[node] = index_counter[0]
lowlinks[node] = index_counter[0]
index_counter[0] += 1
stack.append(node)
on_stack[node] = True

successors = graph.get(node, set())
for successor in successors:
if successor not in index:
strongconnect(successor)
lowlinks[node] = min(lowlinks[node], lowlinks[successor])
elif on_stack.get(successor, False):
lowlinks[node] = min(lowlinks[node], index[successor])

if lowlinks[node] == index[node]:
scc: List[ir_types.TypeId] = []
while True:
successor = stack.pop()
on_stack[successor] = False
scc.append(successor)
if successor == node:
break
sccs.append(scc)

for node in graph.keys():
if node not in index:
strongconnect(node)

return sccs
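# Tracing the example above starting from A: the walk pushes A, B, C; C finds A
# already on the stack, so A's index propagates back up through the lowlinks
# and the stack is popped down to A, yielding the SCC [C, B, A].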

def get_all_circular_clusters(self) -> List[Set[ir_types.TypeId]]:
"""Returns all circular dependency clusters (with 2+ types)."""
if not self._computed:
self._compute_clusters()

seen_clusters: Set[frozenset[ir_types.TypeId]] = set()
circular_clusters: List[Set[ir_types.TypeId]] = []

for type_id, cluster in self._type_clusters.items():
if len(cluster) > 1:
cluster_key = frozenset(cluster)
if cluster_key not in seen_clusters:
seen_clusters.add(cluster_key)
circular_clusters.append(cluster)

return circular_clusters
@@ -292,7 +292,13 @@ def finish(self) -> None:
):
self._pydantic_model.add_partial_class()
self._get_validators_generator().add_validators()
if self._model_contains_forward_refs or self._force_update_forward_refs:

type_id_for_circular_check = self._type_id_for_forward_ref()
is_in_circular_cluster = type_id_for_circular_check is not None and self._context.is_in_circular_cluster(
type_id_for_circular_check
)

if (self._model_contains_forward_refs or self._force_update_forward_refs) and not is_in_circular_cluster:
self._pydantic_model.update_forward_refs()

# Acknowledge forward refs for extended models as well
@@ -1,19 +1,20 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Literal, Tuple
from typing import TYPE_CHECKING, Literal, Set, Tuple

if TYPE_CHECKING:
from ..context.pydantic_generator_context import PydanticGeneratorContext

from ..context.pydantic_generator_context_impl import PydanticGeneratorContextImpl
from .circular_dependency_detector import CircularDependencyDetector
from .custom_config import PydanticModelCustomConfig
from .type_declaration_handler import (
TypeDeclarationHandler,
TypeDeclarationSnippetGeneratorBuilder,
)
from .type_declaration_referencer import TypeDeclarationReferencer
from fern_python.cli.abstract_generator import AbstractGenerator
from fern_python.codegen import Project
from fern_python.codegen import AST, Filepath, Project
from fern_python.generator_exec_wrapper import GeneratorExecWrapper
from fern_python.generators.pydantic_model.model_utilities import can_be_fern_model
from fern_python.snippet import SnippetRegistry, SnippetWriter
@@ -117,17 +118,37 @@ def generate_types(
snippet_registry: SnippetRegistry,
snippet_writer: SnippetWriter,
) -> None:
for type_to_generate in ir.types.values():
self._generate_type(
project,
type=type_to_generate,
detector = CircularDependencyDetector(ir)
circular_clusters = detector.get_all_circular_clusters()

types_in_clusters: Set[ir_types.TypeId] = set()
for cluster in circular_clusters:
types_in_clusters.update(cluster)

for cluster in circular_clusters:
self._generate_circular_cluster(
project=project,
cluster=cluster,
generator_exec_wrapper=generator_exec_wrapper,
custom_config=custom_config,
context=context,
snippet_registry=snippet_registry,
snippet_writer=snippet_writer,
ir=ir,
)

for type_to_generate in ir.types.values():
if type_to_generate.name.type_id not in types_in_clusters:
self._generate_type(
project,
type=type_to_generate,
generator_exec_wrapper=generator_exec_wrapper,
custom_config=custom_config,
context=context,
snippet_registry=snippet_registry,
snippet_writer=snippet_writer,
)

def _should_generate_typedict(self, context: "PydanticGeneratorContext", type_: ir_types.Type) -> bool:
return context.use_typeddict_requests and can_be_fern_model(type_, context.ir.types)

@@ -185,6 +206,94 @@ def _generate_type(
)
project.write_source_file(source_file=source_file, filepath=filepath)

def _generate_circular_cluster(
self,
project: Project,
cluster: Set[ir_types.TypeId],
generator_exec_wrapper: GeneratorExecWrapper,
custom_config: PydanticModelCustomConfig,
context: "PydanticGeneratorContext",
snippet_registry: SnippetRegistry,
snippet_writer: SnippetWriter,
ir: ir_types.IntermediateRepresentation,
) -> None:
sorted_cluster = sorted(cluster, key=lambda tid: str(tid))
canonical_type_id = sorted_cluster[0]
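# The cluster is sorted so the same type id is always chosen as the canonical
# one, which keeps the consolidated module name deterministic across runs.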

cluster_names = [ir.types[tid].name.name.snake_case.safe_name for tid in sorted_cluster]
print(f"[CIRCULAR CLUSTER] Generating cluster with {len(cluster_names)} types: {cluster_names}")

base_filepath = context.get_filepath_for_type_id(type_id=canonical_type_id, as_request=False)

canonical_decl = ir.types[canonical_type_id]
consolidated_filename = f"{canonical_decl.name.name.snake_case.safe_name}_all"

consolidated_filepath = Filepath(
directories=base_filepath.directories,
file=Filepath.FilepathPart(module_name=consolidated_filename),
)

source_file = context.source_file_factory.create(
project=project,
filepath=consolidated_filepath,
generator_exec_wrapper=generator_exec_wrapper,
)

for type_id in sorted_cluster:
type_decl = ir.types[type_id]

if self._should_generate_typedict(context=context, type_=type_decl.shape):
typeddict_filepath = context.get_filepath_for_type_id(type_id=type_id, as_request=True)
typeddict_source_file = context.source_file_factory.create(
project=project, filepath=typeddict_filepath, generator_exec_wrapper=generator_exec_wrapper
)

typeddict_handler = TypeDeclarationHandler(
declaration=type_decl,
context=context,
custom_config=custom_config,
source_file=typeddict_source_file,
snippet_writer=snippet_writer,
generate_typeddict_request=True,
)
typeddict_handler.run()

project.write_source_file(source_file=typeddict_source_file, filepath=typeddict_filepath)

type_declaration_handler = TypeDeclarationHandler(
declaration=type_decl,
context=context,
custom_config=custom_config,
source_file=source_file,
snippet_writer=snippet_writer,
generate_typeddict_request=False,
)
generated_type = type_declaration_handler.run()

if generated_type.snippet is not None:
snippet_registry.register_snippet(
type_id=type_decl.name.type_id,
expr=generated_type.snippet,
)

project.write_source_file(source_file=source_file, filepath=consolidated_filepath)

for type_id in sorted_cluster:
type_decl = ir.types[type_id]

# Create stub file for the regular type (as_request=False)
individual_filepath = context.get_filepath_for_type_id(type_id=type_id, as_request=False)
class_name = context.get_class_name_for_type_id(type_id=type_id, as_request=False)

stub_content = (
f"# This file was auto-generated by Fern from our API Definition.\n\n"
f"from .{consolidated_filename} import {class_name} as {class_name}\n\n"
f'__all__ = ["{class_name}"]\n'
)
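# With a hypothetical class "Expression" consolidated into "expression_all",
# the stub re-exports it so existing import paths keep resolving, e.g.:
#   from .expression_all import Expression as Expression
#   __all__ = ["Expression"]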

stub_filepath = project.get_relative_source_file_filepath(individual_filepath)
project.add_file(stub_filepath, stub_content)

def get_sorted_modules(self) -> None:
return None
