Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
from abc import ABC, abstractmethod
from typing import Optional
from typing import Optional, TYPE_CHECKING

from fern_python.codegen import AST
from fern_python.snippet import SnippetWriter

if TYPE_CHECKING:
from fern_python.snippet.recursion_guard import RecursionGuard


class AbstractTypeSnippetGenerator(ABC):
def __init__(
Expand All @@ -13,4 +16,4 @@ def __init__(
self.snippet_writer = snippet_writer

@abstractmethod
def generate_snippet(self) -> Optional[AST.Expression]: ...
def generate_snippet(self, recursion_guard: Optional["RecursionGuard"] = None) -> Optional[AST.Expression]: ...
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from abc import ABC
from typing import Optional
from typing import Optional, TYPE_CHECKING

from ...context.pydantic_generator_context import PydanticGeneratorContext
from ..custom_config import PydanticModelCustomConfig
Expand All @@ -12,6 +12,9 @@

import fern.ir.resources as ir_types

if TYPE_CHECKING:
from fern_python.snippet.recursion_guard import RecursionGuard


class AbstractAliasGenerator(AbstractTypeGenerator, ABC):
def __init__(
Expand Down Expand Up @@ -52,9 +55,10 @@ def __init__(
self.as_request = as_request
self.example = example

def generate_snippet(self) -> Optional[AST.Expression]:
def generate_snippet(self, recursion_guard: Optional["RecursionGuard"] = None) -> Optional[AST.Expression]:
return self.snippet_writer.get_snippet_for_example_type_reference(
example_type_reference=self.example.value,
use_typeddict_request=self.use_typeddict_request,
as_request=self.as_request,
recursion_guard=recursion_guard,
)
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Optional, Union
from typing import Optional, Union, TYPE_CHECKING

from ...context.pydantic_generator_context import PydanticGeneratorContext
from ..custom_config import PydanticModelCustomConfig
Expand All @@ -12,6 +12,9 @@

import fern.ir.resources as ir_types

if TYPE_CHECKING:
from fern_python.snippet.recursion_guard import RecursionGuard


# Note enums are the same for both pydantic models and typeddicts os the generator is not multiplexed
class EnumGenerator(AbstractTypeGenerator):
Expand Down Expand Up @@ -166,7 +169,7 @@ def __init__(
self.name = name
self.example = example.value if isinstance(example, ir_types.ExampleEnumType) else example

def generate_snippet(self) -> AST.Expression:
def generate_snippet(self, recursion_guard: Optional["RecursionGuard"] = None) -> AST.Expression:
class_reference = self.snippet_writer.get_class_reference_for_declared_type_name(
name=self.name,
as_request=False,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import List, Optional
from typing import List, Optional, TYPE_CHECKING

from ....context.pydantic_generator_context import PydanticGeneratorContext
from ...custom_config import PydanticModelCustomConfig
Expand All @@ -13,6 +13,9 @@

import fern.ir.resources as ir_types

if TYPE_CHECKING:
from fern_python.snippet.recursion_guard import RecursionGuard


class PydanticModelObjectGenerator(AbstractObjectGenerator):
def __init__(
Expand Down Expand Up @@ -97,7 +100,7 @@ def __init__(
example=example,
)

def generate_snippet(self) -> AST.Expression:
def generate_snippet(self, recursion_guard: Optional["RecursionGuard"] = None) -> AST.Expression:
return AST.Expression(
AST.ClassInstantiation(
class_=self.snippet_writer.get_class_reference_for_declared_type_name(
Expand All @@ -110,6 +113,7 @@ def generate_snippet(self) -> AST.Expression:
use_typeddict_request=False,
as_request=False,
in_typeddict=False,
recursion_guard=recursion_guard,
),
),
)
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Optional
from typing import Optional, TYPE_CHECKING

from ...context.pydantic_generator_context import PydanticGeneratorContext
from .enum_generator import EnumSnippetGenerator
Expand Down Expand Up @@ -27,6 +27,9 @@

import fern.ir.resources as ir_types

if TYPE_CHECKING:
from fern_python.snippet.recursion_guard import RecursionGuard


class TypeDeclarationSnippetGeneratorBuilder:
def __init__(
Expand All @@ -41,78 +44,78 @@ def get_generator(
self,
) -> TypeDeclarationSnippetGenerator:
return TypeDeclarationSnippetGenerator(
alias=lambda example: self._get_alias_snippet_generator(example),
enum=lambda name, example: EnumSnippetGenerator(
alias=lambda example, recursion_guard=None: self._get_alias_snippet_generator(example, recursion_guard),
enum=lambda name, example, recursion_guard=None: EnumSnippetGenerator(
snippet_writer=self._snippet_writer,
name=name,
example=example,
use_str_enums=self._context.use_str_enums,
).generate_snippet(),
object=lambda name, example: self._get_object_snippet_generator(name, example),
discriminated_union=lambda name, example: self._get_discriminated_union_snippet_generator(name, example),
undiscriminated_union=lambda name, example: self._get_undiscriminated_union_snippet_generator(
name, example
).generate_snippet(recursion_guard),
object=lambda name, example, recursion_guard=None: self._get_object_snippet_generator(name, example, recursion_guard),
discriminated_union=lambda name, example, recursion_guard=None: self._get_discriminated_union_snippet_generator(name, example, recursion_guard),
undiscriminated_union=lambda name, example, recursion_guard=None: self._get_undiscriminated_union_snippet_generator(
name, example, recursion_guard
),
)

def _get_alias_snippet_generator(self, example: ir_types.ExampleAliasType) -> Optional[AST.Expression]:
def _get_alias_snippet_generator(self, example: ir_types.ExampleAliasType, recursion_guard: Optional["RecursionGuard"] = None) -> Optional[AST.Expression]:
if self._context.use_typeddict_requests:
return TypedDictAliasSnippetGenerator(
snippet_writer=self._snippet_writer,
example=example,
).generate_snippet()
).generate_snippet(recursion_guard)

return PydanticModelAliasSnippetGenerator(
snippet_writer=self._snippet_writer,
example=example,
).generate_snippet()
).generate_snippet(recursion_guard)

def _get_object_snippet_generator(
self, name: ir_types.DeclaredTypeName, example: ir_types.ExampleObjectType
self, name: ir_types.DeclaredTypeName, example: ir_types.ExampleObjectType, recursion_guard: Optional["RecursionGuard"] = None
) -> AST.Expression:
if self._context.use_typeddict_requests:
return TypeddictObjectSnippetGenerator(
name=name,
snippet_writer=self._snippet_writer,
example=example,
).generate_snippet()
).generate_snippet(recursion_guard)

return PydanticModelObjectSnippetGenerator(
name=name,
snippet_writer=self._snippet_writer,
example=example,
).generate_snippet()
).generate_snippet(recursion_guard)

def _get_discriminated_union_snippet_generator(
self, name: ir_types.DeclaredTypeName, example: ir_types.ExampleUnionType
self, name: ir_types.DeclaredTypeName, example: ir_types.ExampleUnionType, recursion_guard: Optional["RecursionGuard"] = None
) -> AST.Expression:
if self._context.use_typeddict_requests:
return TypeddictDiscriminatedUnionSnippetGenerator(
name=name,
snippet_writer=self._snippet_writer,
example=example,
union_naming_version=self._context.union_naming_version,
).generate_snippet()
).generate_snippet(recursion_guard)

return PydanticModelDiscriminatedUnionSnippetGenerator(
name=name,
snippet_writer=self._snippet_writer,
example=example,
union_naming_version=self._context.union_naming_version,
).generate_snippet()
).generate_snippet(recursion_guard)

def _get_undiscriminated_union_snippet_generator(
self, name: ir_types.DeclaredTypeName, example: ir_types.ExampleUndiscriminatedUnionType
self, name: ir_types.DeclaredTypeName, example: ir_types.ExampleUndiscriminatedUnionType, recursion_guard: Optional["RecursionGuard"] = None
) -> Optional[AST.Expression]:
if self._context.use_typeddict_requests:
return TypeddictUndiscriminatedUnionSnippetGenerator(
name=name,
snippet_writer=self._snippet_writer,
example=example,
).generate_snippet()
).generate_snippet(recursion_guard)

return PydanticModelUndiscriminatedUnionSnippetGenerator(
name=name,
snippet_writer=self._snippet_writer,
example=example,
).generate_snippet()
).generate_snippet(recursion_guard)
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import List, Optional
from typing import List, Optional, TYPE_CHECKING

from ....context.pydantic_generator_context import PydanticGeneratorContext
from ...custom_config import PydanticModelCustomConfig
Expand All @@ -13,6 +13,9 @@

import fern.ir.resources as ir_types

if TYPE_CHECKING:
from fern_python.snippet.recursion_guard import RecursionGuard


class TypeddictObjectGenerator(AbstractObjectGenerator):
def __init__(
Expand Down Expand Up @@ -72,5 +75,5 @@ def __init__(
example=example,
)

def generate_snippet(self) -> AST.Expression:
return FernTypedDict.type_to_snippet(example=self.example, snippet_writer=self.snippet_writer)
def generate_snippet(self, recursion_guard: Optional["RecursionGuard"] = None) -> AST.Expression:
return FernTypedDict.type_to_snippet(example=self.example, snippet_writer=self.snippet_writer, recursion_guard=recursion_guard)
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from abc import ABC
from dataclasses import dataclass
from typing import List, Optional
from typing import List, Optional, TYPE_CHECKING

from ...context.pydantic_generator_context import PydanticGeneratorContext
from ..custom_config import PydanticModelCustomConfig
Expand All @@ -13,6 +13,9 @@

import fern.ir.resources as ir_types

if TYPE_CHECKING:
from fern_python.snippet.recursion_guard import RecursionGuard


@dataclass(frozen=True)
class CycleAwareMemberType:
Expand Down Expand Up @@ -71,9 +74,10 @@ def __init__(
self.as_request = as_request
self.use_typeddict_request = use_typeddict_request

def generate_snippet(self) -> Optional[AST.Expression]:
def generate_snippet(self, recursion_guard: Optional["RecursionGuard"] = None) -> Optional[AST.Expression]:
return self.snippet_writer.get_snippet_for_example_type_reference(
example_type_reference=self.example.single_union_type,
use_typeddict_request=self.use_typeddict_request,
as_request=self.as_request,
recursion_guard=recursion_guard,
)
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from dataclasses import dataclass
from types import TracebackType
from typing import Dict, List, Optional, Sequence, Type
from typing import Dict, List, Optional, Sequence, Type, TYPE_CHECKING

from ..context.pydantic_generator_context import PydanticGeneratorContext
from fern_python.codegen import AST, SourceFile
Expand All @@ -16,6 +16,9 @@

import fern.ir.resources as ir_types

if TYPE_CHECKING:
from fern_python.snippet.recursion_guard import RecursionGuard

TYPING_EXTENSIONS_MODULE = AST.Module.external(
module_path=("typing_extensions",),
dependency=TYPING_EXTENSIONS_DEPENDENCY,
Expand Down Expand Up @@ -203,7 +206,7 @@ def wrap_string_as_example(cls, string: str) -> ir_types.ExampleTypeReference:

@classmethod
def snippet_from_properties(
cls, example_properties: List[SimpleObjectProperty], snippet_writer: SnippetWriter
cls, example_properties: List[SimpleObjectProperty], snippet_writer: SnippetWriter, recursion_guard: Optional["RecursionGuard"] = None
) -> AST.Expression:
example_dict_pairs: List[ir_types.ExampleKeyValuePair] = []
for property in example_properties:
Expand All @@ -219,7 +222,7 @@ def snippet_from_properties(
)
)
return snippet_writer._get_snippet_for_map(
example_dict_pairs, use_typeddict_request=True, as_request=True, in_typeddict=True
example_dict_pairs, use_typeddict_request=True, as_request=True, in_typeddict=True, recursion_guard=recursion_guard
)

@classmethod
Expand All @@ -228,6 +231,7 @@ def type_to_snippet(
example: ir_types.ExampleObjectType,
snippet_writer: SnippetWriter,
additional_properties: List[SimpleObjectProperty] = [],
recursion_guard: Optional["RecursionGuard"] = None,
) -> AST.Expression:
example_properties = [
SimpleObjectProperty(
Expand All @@ -240,6 +244,7 @@ def type_to_snippet(
return cls.snippet_from_properties(
example_properties=example_properties,
snippet_writer=snippet_writer,
recursion_guard=recursion_guard,
)

@classmethod
Expand Down
50 changes: 50 additions & 0 deletions generators/python/src/fern_python/snippet/recursion_guard.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
from typing import Optional, Set
import fern.ir.resources as ir_types


class RecursionGuard:
"""
Guards against infinite recursion when generating examples for types with circular references.
Tracks visited types on the current recursion path and enforces a maximum depth limit.
"""

def __init__(self, max_depth: int = 5):
self._visited: Set[str] = set()
self._depth: int = 0
self._max_depth: int = max_depth

def _get_type_key(self, type_name: ir_types.DeclaredTypeName) -> str:
"""Generate a unique key for a type based on its package path and name."""
fern_filepath = ".".join(type_name.fern_filepath.package_path.parts) if type_name.fern_filepath.package_path.parts else ""
return f"{fern_filepath}:{type_name.name.original_name}"

def can_recurse(self, type_name: ir_types.DeclaredTypeName) -> bool:
"""
Check if we can safely recurse into the given type.
Returns False if the type is already on the recursion stack or if max depth is exceeded.
"""
if self._depth >= self._max_depth:
return False

type_key = self._get_type_key(type_name)
return type_key not in self._visited

def enter(self, type_name: ir_types.DeclaredTypeName) -> "RecursionGuard":
"""
Enter a new recursion level for the given type.
Returns a new RecursionGuard with the type added to the visited set.
"""
new_guard = RecursionGuard(max_depth=self._max_depth)
new_guard._visited = self._visited.copy()
new_guard._visited.add(self._get_type_key(type_name))
new_guard._depth = self._depth + 1
return new_guard

def with_depth_increment(self) -> "RecursionGuard":
"""
Increment depth without adding to visited set (for containers like lists/maps).
"""
new_guard = RecursionGuard(max_depth=self._max_depth)
new_guard._visited = self._visited.copy()
new_guard._depth = self._depth + 1
return new_guard
Loading
Loading