diff --git a/docs/api-reference/core/ImportType.mdx b/docs/api-reference/core/ImportType.mdx
new file mode 100644
index 000000000..857b1c60a
--- /dev/null
+++ b/docs/api-reference/core/ImportType.mdx
@@ -0,0 +1,44 @@
+---
+title: "ImportType"
+sidebarTitle: "ImportType"
+icon: ""
+description: "Import types for each import object. Determines what the import resolves to, and what symbols are imported."
+---
+import {Parameter} from '/snippets/Parameter.mdx';
+import {ParameterWrapper} from '/snippets/ParameterWrapper.mdx';
+import {Return} from '/snippets/Return.mdx';
+import {HorizontalDivider} from '/snippets/HorizontalDivider.mdx';
+import {GithubLinkNote} from '/snippets/GithubLinkNote.mdx';
+import {Attribute} from '/snippets/Attribute.mdx';
+
+
+
+
+## Attributes
+
+### DEFAULT_EXPORT
+
+ } description="Imports all default exports. Resolves to the file." />
+
+### MODULE
+
+ } description="Imports the module, not doesn't actually allow access to any of the exports" />
+
+### NAMED_EXPORT
+
+ } description="Imports a named export. Resolves to the symbol export." />
+
+### SIDE_EFFECT
+
+ } description="Imports the module, not doesn't actually allow access to any of the exports" />
+
+### UNKNOWN
+
+ } description="Unknown import type." />
+
+### WILDCARD
+
+ } description="Imports all named exports, and default exports as `default`. Resolves to the file." />
+
+
+
diff --git a/ruff.toml b/ruff.toml
index 338b939e4..82675c8a6 100644
--- a/ruff.toml
+++ b/ruff.toml
@@ -137,6 +137,7 @@ extend-generics = [
"codegen.sdk.core.symbol_groups.multi_line_collection.MultiLineCollection",
"codegen.sdk.core.symbol_groups.tuple.Tuple",
"codegen.sdk.core.type_alias.TypeAlias",
+ "codegen.sdk.enums.ImportType",
"codegen.sdk.python.statements.if_block_statement.PyIfBlockStatement",
"codegen.sdk.python.statements.with_statement.WithStatement",
"codegen.sdk.typescript.statements.block_statement.TSBlockStatement",
diff --git a/src/codegen/sdk/code_generation/doc_utils/generate_docs_json.py b/src/codegen/sdk/code_generation/doc_utils/generate_docs_json.py
index 12804fd50..8fdda80b8 100644
--- a/src/codegen/sdk/code_generation/doc_utils/generate_docs_json.py
+++ b/src/codegen/sdk/code_generation/doc_utils/generate_docs_json.py
@@ -1,10 +1,8 @@
-from typing import Any
-
from tqdm import tqdm
from codegen.sdk.code_generation.doc_utils.parse_docstring import parse_docstring
from codegen.sdk.code_generation.doc_utils.schemas import ClassDoc, GSDocs, MethodDoc
-from codegen.sdk.code_generation.doc_utils.utils import create_path, get_langauge, get_type, get_type_str, has_documentation, is_settter, replace_multiple_types
+from codegen.sdk.code_generation.doc_utils.utils import create_path, extract_class_description, get_langauge, get_type, get_type_str, has_documentation, is_settter, replace_multiple_types
from codegen.sdk.core.class_definition import Class
from codegen.sdk.core.codebase import Codebase
from codegen.sdk.core.placeholder.placeholder_type import TypePlaceholder
@@ -27,7 +25,7 @@
]
-def generate_docs_json(codebase: Codebase, head_commit: str, raise_on_missing_docstring: bool = False) -> dict[str, dict[str, Any]]:
+def generate_docs_json(codebase: Codebase, head_commit: str, raise_on_missing_docstring: bool = False) -> GSDocs:
"""Update documentation table for classes, methods and attributes in the codebase.
Args:
@@ -46,7 +44,7 @@ def process_class_doc(cls):
cls_doc = ClassDoc(
title=cls.name,
- description=description,
+ description=extract_class_description(description),
content=" ",
path=create_path(cls),
inherits_from=parent_classes,
diff --git a/src/codegen/sdk/code_generation/doc_utils/utils.py b/src/codegen/sdk/code_generation/doc_utils/utils.py
index d0b939de2..10e690e2d 100644
--- a/src/codegen/sdk/code_generation/doc_utils/utils.py
+++ b/src/codegen/sdk/code_generation/doc_utils/utils.py
@@ -14,6 +14,9 @@
logger = logging.getLogger(__name__)
+# These are the classes that are not language specific, but have language specific subclasses with different names
+SPECIAL_BASE_CLASSES = {"SourceFile": "File"}
+
def sanitize_docstring_for_markdown(docstring: str | None) -> str:
"""Sanitize the docstring for MDX"""
@@ -82,6 +85,9 @@ def is_language_base_class(cls_obj: Class):
Returns:
bool: if `cls_obj` is a language base class
"""
+ if cls_obj.name in SPECIAL_BASE_CLASSES:
+ return True
+
sub_classes = cls_obj.subclasses(max_depth=1)
base_name = cls_obj.name.lower()
return any(sub_class.name.lower() in [f"py{base_name}", f"ts{base_name}"] for sub_class in sub_classes)
@@ -184,24 +190,53 @@ def has_documentation(c: Class):
return any([dec.name == "ts_apidoc" or dec.name == "py_apidoc" or dec.name == "apidoc" for dec in c.decorators])
-def safe_get_class(codebase: Codebase, class_name: str) -> Class | None:
- symbols = codebase.get_symbols(class_name)
- if not symbols:
- return None
-
- if len(symbols) == 1 and isinstance(symbols[0], Class):
- return symbols[0]
-
- possible_classes = [s for s in symbols if isinstance(s, Class) and has_documentation(s)]
- if not possible_classes:
- return None
- if len(possible_classes) == 1:
- return possible_classes[0]
- msg = f"Found {len(possible_classes)} classes with name {class_name}"
- raise ValueError(msg)
+def safe_get_class(codebase: Codebase, class_name: str, language: str | None = None) -> Class | None:
+ """Find the class in the codebase.
-
-def find_symbol(codebase: Codebase, symbol_name: str, resolved_types: list[Type], parent_class: Class, parent_symbol: Symbol, types_cache: dict):
+ Args:
+ codebase (Codebase): the codebase to search in
+ class_name (str): the name of the class to resolve
+ language (str | None): the language of the class to resolve
+ Returns:
+ Class | None: the class if found, None otherwise
+ """
+ if '"' in class_name:
+ class_name = class_name.strip('"')
+ if "'" in class_name:
+ class_name = class_name.strip("'")
+
+ symbols = []
+ try:
+ class_obj = codebase.get_class(class_name, optional=True)
+ if not class_obj:
+ return None
+
+ except Exception:
+ symbols = codebase.get_symbols(class_name)
+ possible_classes = [s for s in symbols if isinstance(s, Class) and has_documentation(s)]
+ if not possible_classes:
+ return None
+ if len(possible_classes) > 1:
+ msg = f"Found {len(possible_classes)} classes with name {class_name}"
+ raise ValueError(msg)
+ class_obj = possible_classes[0]
+
+ if language and is_language_base_class(class_obj):
+ sub_classes = class_obj.subclasses(max_depth=1)
+
+ if class_name in SPECIAL_BASE_CLASSES:
+ class_name = SPECIAL_BASE_CLASSES[class_name]
+
+ if language == ProgrammingLanguage.PYTHON.value:
+ sub_classes = [s for s in sub_classes if s.name == f"Py{class_name}"]
+ elif language == ProgrammingLanguage.TYPESCRIPT.value:
+ sub_classes = [s for s in sub_classes if s.name == f"TS{class_name}"]
+ if len(sub_classes) == 1:
+ class_obj = sub_classes[0]
+ return class_obj
+
+
+def resolve_type_symbol(codebase: Codebase, symbol_name: str, resolved_types: list[Type], parent_class: Class, parent_symbol: Symbol, types_cache: dict):
"""Find the symbol in the codebase.
Args:
@@ -217,11 +252,13 @@ def find_symbol(codebase: Codebase, symbol_name: str, resolved_types: list[Type]
return symbol_name
if symbol_name.lower() == "self":
return f"<{create_path(parent_class)}>"
- if symbol_name in types_cache:
- return types_cache[symbol_name]
+
+ language = get_langauge(parent_class)
+ if (symbol_name, language) in types_cache:
+ return types_cache[(symbol_name, language)]
trgt_symbol = None
- cls_obj = safe_get_class(codebase, symbol_name)
+ cls_obj = safe_get_class(codebase=codebase, class_name=symbol_name, language=language)
if cls_obj:
trgt_symbol = cls_obj
@@ -230,8 +267,8 @@ def find_symbol(codebase: Codebase, symbol_name: str, resolved_types: list[Type]
for resolved_type in symbol.resolved_types:
if isinstance(resolved_type, FunctionCall) and len(resolved_type.args) >= 2:
bound_arg = resolved_type.args[1]
- bound_name = bound_arg.value
- if cls_obj := safe_get_class(codebase, bound_name):
+ bound_name = bound_arg.value.source
+ if cls_obj := safe_get_class(codebase, bound_name, language=get_langauge(parent_class)):
trgt_symbol = cls_obj
break
@@ -241,7 +278,7 @@ def find_symbol(codebase: Codebase, symbol_name: str, resolved_types: list[Type]
if trgt_symbol and isinstance(trgt_symbol, Callable) and has_documentation(trgt_symbol):
trgt_path = f"<{create_path(trgt_symbol)}>"
- types_cache[symbol_name] = trgt_path
+ types_cache[(symbol_name, language)] = trgt_path
return trgt_path
return symbol_name
@@ -318,10 +355,12 @@ def process_parts(content):
base_type = part[: part.index("[")]
bracket_content = part[part.index("[") :].strip("[]")
processed_bracket = process_parts(bracket_content)
- replacement = find_symbol(codebase=codebase, symbol_name=base_type, resolved_types=resolved_types, parent_class=parent_class, parent_symbol=parent_symbol, types_cache=types_cache)
+ replacement = resolve_type_symbol(
+ codebase=codebase, symbol_name=base_type, resolved_types=resolved_types, parent_class=parent_class, parent_symbol=parent_symbol, types_cache=types_cache
+ )
processed_part = replacement + "[" + processed_bracket + "]"
else:
- replacement = find_symbol(codebase=codebase, symbol_name=part, resolved_types=resolved_types, parent_class=parent_class, parent_symbol=parent_symbol, types_cache=types_cache)
+ replacement = resolve_type_symbol(codebase=codebase, symbol_name=part, resolved_types=resolved_types, parent_class=parent_class, parent_symbol=parent_symbol, types_cache=types_cache)
processed_part = replacement
processed_parts.append(processed_part)
@@ -340,9 +379,30 @@ def process_parts(content):
base_type = input_str[: input_str.index("[")]
bracket_content = input_str[input_str.index("[") :].strip("[]")
processed_content = process_parts(bracket_content)
- replacement = find_symbol(codebase=codebase, symbol_name=base_type, resolved_types=resolved_types, parent_class=parent_class, parent_symbol=parent_symbol, types_cache=types_cache)
+ replacement = resolve_type_symbol(codebase=codebase, symbol_name=base_type, resolved_types=resolved_types, parent_class=parent_class, parent_symbol=parent_symbol, types_cache=types_cache)
return replacement + "[" + processed_content + "]"
# Handle simple input
else:
- replacement = find_symbol(codebase=codebase, symbol_name=input_str, resolved_types=resolved_types, parent_class=parent_class, parent_symbol=parent_symbol, types_cache=types_cache)
+ replacement = resolve_type_symbol(codebase=codebase, symbol_name=input_str, resolved_types=resolved_types, parent_class=parent_class, parent_symbol=parent_symbol, types_cache=types_cache)
return replacement
+
+
+def extract_class_description(docstring):
+ """Extract the class description from a docstring, excluding the attributes section.
+
+ Args:
+ docstring (str): The class docstring to parse
+
+ Returns:
+ str: The class description with whitespace normalized
+ """
+ if not docstring:
+ return ""
+
+ # Split by "Attributes:" and take only the first part
+ parts = docstring.split("Attributes:")
+ description = parts[0]
+
+ # Normalize whitespace
+ lines = [line.strip() for line in description.strip().splitlines()]
+ return " ".join(filter(None, lines))
diff --git a/src/codegen/sdk/enums.py b/src/codegen/sdk/enums.py
index 09d07344a..c2fe533af 100644
--- a/src/codegen/sdk/enums.py
+++ b/src/codegen/sdk/enums.py
@@ -2,6 +2,7 @@
from typing import NamedTuple
from codegen.sdk.core.dataclasses.usage import Usage
+from codegen.shared.decorators.docs import apidoc
class NodeType(IntEnum):
@@ -51,8 +52,18 @@ class SymbolType(IntEnum):
Namespace = auto()
+@apidoc
class ImportType(IntEnum):
- """Import types for each import object. Determines what the import resolves to, and what symbols are imported."""
+ """Import types for each import object. Determines what the import resolves to, and what symbols are imported.
+
+ Attributes:
+ DEFAULT_EXPORT: Imports all default exports. Resolves to the file.
+ NAMED_EXPORT: Imports a named export. Resolves to the symbol export.
+ WILDCARD: Imports all named exports, and default exports as `default`. Resolves to the file.
+ MODULE: Imports the module, not doesn't actually allow access to any of the exports
+ SIDE_EFFECT: Imports the module, not doesn't actually allow access to any of the exports
+ UNKNOWN: Unknown import type.
+ """
# Imports all default exports. Resolves to the file.
DEFAULT_EXPORT = auto()
diff --git a/src/codegen/shared/compilation/function_imports.py b/src/codegen/shared/compilation/function_imports.py
index 67ce6f0b3..4b01484b0 100644
--- a/src/codegen/shared/compilation/function_imports.py
+++ b/src/codegen/shared/compilation/function_imports.py
@@ -108,6 +108,7 @@ def get_generated_imports():
from codegen.sdk.core.symbol_groups.multi_line_collection import MultiLineCollection
from codegen.sdk.core.symbol_groups.tuple import Tuple
from codegen.sdk.core.type_alias import TypeAlias
+from codegen.sdk.enums import ImportType
from codegen.sdk.python.assignment import PyAssignment
from codegen.sdk.python.class_definition import PyClass
from codegen.sdk.python.detached_symbols.code_block import PyCodeBlock