diff --git a/docs/api-reference/core/ImportType.mdx b/docs/api-reference/core/ImportType.mdx new file mode 100644 index 000000000..857b1c60a --- /dev/null +++ b/docs/api-reference/core/ImportType.mdx @@ -0,0 +1,44 @@ +--- +title: "ImportType" +sidebarTitle: "ImportType" +icon: "" +description: "Import types for each import object. Determines what the import resolves to, and what symbols are imported." +--- +import {Parameter} from '/snippets/Parameter.mdx'; +import {ParameterWrapper} from '/snippets/ParameterWrapper.mdx'; +import {Return} from '/snippets/Return.mdx'; +import {HorizontalDivider} from '/snippets/HorizontalDivider.mdx'; +import {GithubLinkNote} from '/snippets/GithubLinkNote.mdx'; +import {Attribute} from '/snippets/Attribute.mdx'; + + + + +## Attributes + +### DEFAULT_EXPORT + + } description="Imports all default exports. Resolves to the file." /> + +### MODULE + + } description="Imports the module, not doesn't actually allow access to any of the exports" /> + +### NAMED_EXPORT + + } description="Imports a named export. Resolves to the symbol export." /> + +### SIDE_EFFECT + + } description="Imports the module, not doesn't actually allow access to any of the exports" /> + +### UNKNOWN + + } description="Unknown import type." /> + +### WILDCARD + + } description="Imports all named exports, and default exports as `default`. Resolves to the file." /> + + + diff --git a/ruff.toml b/ruff.toml index 338b939e4..82675c8a6 100644 --- a/ruff.toml +++ b/ruff.toml @@ -137,6 +137,7 @@ extend-generics = [ "codegen.sdk.core.symbol_groups.multi_line_collection.MultiLineCollection", "codegen.sdk.core.symbol_groups.tuple.Tuple", "codegen.sdk.core.type_alias.TypeAlias", + "codegen.sdk.enums.ImportType", "codegen.sdk.python.statements.if_block_statement.PyIfBlockStatement", "codegen.sdk.python.statements.with_statement.WithStatement", "codegen.sdk.typescript.statements.block_statement.TSBlockStatement", diff --git a/src/codegen/sdk/code_generation/doc_utils/generate_docs_json.py b/src/codegen/sdk/code_generation/doc_utils/generate_docs_json.py index 12804fd50..8fdda80b8 100644 --- a/src/codegen/sdk/code_generation/doc_utils/generate_docs_json.py +++ b/src/codegen/sdk/code_generation/doc_utils/generate_docs_json.py @@ -1,10 +1,8 @@ -from typing import Any - from tqdm import tqdm from codegen.sdk.code_generation.doc_utils.parse_docstring import parse_docstring from codegen.sdk.code_generation.doc_utils.schemas import ClassDoc, GSDocs, MethodDoc -from codegen.sdk.code_generation.doc_utils.utils import create_path, get_langauge, get_type, get_type_str, has_documentation, is_settter, replace_multiple_types +from codegen.sdk.code_generation.doc_utils.utils import create_path, extract_class_description, get_langauge, get_type, get_type_str, has_documentation, is_settter, replace_multiple_types from codegen.sdk.core.class_definition import Class from codegen.sdk.core.codebase import Codebase from codegen.sdk.core.placeholder.placeholder_type import TypePlaceholder @@ -27,7 +25,7 @@ ] -def generate_docs_json(codebase: Codebase, head_commit: str, raise_on_missing_docstring: bool = False) -> dict[str, dict[str, Any]]: +def generate_docs_json(codebase: Codebase, head_commit: str, raise_on_missing_docstring: bool = False) -> GSDocs: """Update documentation table for classes, methods and attributes in the codebase. Args: @@ -46,7 +44,7 @@ def process_class_doc(cls): cls_doc = ClassDoc( title=cls.name, - description=description, + description=extract_class_description(description), content=" ", path=create_path(cls), inherits_from=parent_classes, diff --git a/src/codegen/sdk/code_generation/doc_utils/utils.py b/src/codegen/sdk/code_generation/doc_utils/utils.py index d0b939de2..10e690e2d 100644 --- a/src/codegen/sdk/code_generation/doc_utils/utils.py +++ b/src/codegen/sdk/code_generation/doc_utils/utils.py @@ -14,6 +14,9 @@ logger = logging.getLogger(__name__) +# These are the classes that are not language specific, but have language specific subclasses with different names +SPECIAL_BASE_CLASSES = {"SourceFile": "File"} + def sanitize_docstring_for_markdown(docstring: str | None) -> str: """Sanitize the docstring for MDX""" @@ -82,6 +85,9 @@ def is_language_base_class(cls_obj: Class): Returns: bool: if `cls_obj` is a language base class """ + if cls_obj.name in SPECIAL_BASE_CLASSES: + return True + sub_classes = cls_obj.subclasses(max_depth=1) base_name = cls_obj.name.lower() return any(sub_class.name.lower() in [f"py{base_name}", f"ts{base_name}"] for sub_class in sub_classes) @@ -184,24 +190,53 @@ def has_documentation(c: Class): return any([dec.name == "ts_apidoc" or dec.name == "py_apidoc" or dec.name == "apidoc" for dec in c.decorators]) -def safe_get_class(codebase: Codebase, class_name: str) -> Class | None: - symbols = codebase.get_symbols(class_name) - if not symbols: - return None - - if len(symbols) == 1 and isinstance(symbols[0], Class): - return symbols[0] - - possible_classes = [s for s in symbols if isinstance(s, Class) and has_documentation(s)] - if not possible_classes: - return None - if len(possible_classes) == 1: - return possible_classes[0] - msg = f"Found {len(possible_classes)} classes with name {class_name}" - raise ValueError(msg) +def safe_get_class(codebase: Codebase, class_name: str, language: str | None = None) -> Class | None: + """Find the class in the codebase. - -def find_symbol(codebase: Codebase, symbol_name: str, resolved_types: list[Type], parent_class: Class, parent_symbol: Symbol, types_cache: dict): + Args: + codebase (Codebase): the codebase to search in + class_name (str): the name of the class to resolve + language (str | None): the language of the class to resolve + Returns: + Class | None: the class if found, None otherwise + """ + if '"' in class_name: + class_name = class_name.strip('"') + if "'" in class_name: + class_name = class_name.strip("'") + + symbols = [] + try: + class_obj = codebase.get_class(class_name, optional=True) + if not class_obj: + return None + + except Exception: + symbols = codebase.get_symbols(class_name) + possible_classes = [s for s in symbols if isinstance(s, Class) and has_documentation(s)] + if not possible_classes: + return None + if len(possible_classes) > 1: + msg = f"Found {len(possible_classes)} classes with name {class_name}" + raise ValueError(msg) + class_obj = possible_classes[0] + + if language and is_language_base_class(class_obj): + sub_classes = class_obj.subclasses(max_depth=1) + + if class_name in SPECIAL_BASE_CLASSES: + class_name = SPECIAL_BASE_CLASSES[class_name] + + if language == ProgrammingLanguage.PYTHON.value: + sub_classes = [s for s in sub_classes if s.name == f"Py{class_name}"] + elif language == ProgrammingLanguage.TYPESCRIPT.value: + sub_classes = [s for s in sub_classes if s.name == f"TS{class_name}"] + if len(sub_classes) == 1: + class_obj = sub_classes[0] + return class_obj + + +def resolve_type_symbol(codebase: Codebase, symbol_name: str, resolved_types: list[Type], parent_class: Class, parent_symbol: Symbol, types_cache: dict): """Find the symbol in the codebase. Args: @@ -217,11 +252,13 @@ def find_symbol(codebase: Codebase, symbol_name: str, resolved_types: list[Type] return symbol_name if symbol_name.lower() == "self": return f"<{create_path(parent_class)}>" - if symbol_name in types_cache: - return types_cache[symbol_name] + + language = get_langauge(parent_class) + if (symbol_name, language) in types_cache: + return types_cache[(symbol_name, language)] trgt_symbol = None - cls_obj = safe_get_class(codebase, symbol_name) + cls_obj = safe_get_class(codebase=codebase, class_name=symbol_name, language=language) if cls_obj: trgt_symbol = cls_obj @@ -230,8 +267,8 @@ def find_symbol(codebase: Codebase, symbol_name: str, resolved_types: list[Type] for resolved_type in symbol.resolved_types: if isinstance(resolved_type, FunctionCall) and len(resolved_type.args) >= 2: bound_arg = resolved_type.args[1] - bound_name = bound_arg.value - if cls_obj := safe_get_class(codebase, bound_name): + bound_name = bound_arg.value.source + if cls_obj := safe_get_class(codebase, bound_name, language=get_langauge(parent_class)): trgt_symbol = cls_obj break @@ -241,7 +278,7 @@ def find_symbol(codebase: Codebase, symbol_name: str, resolved_types: list[Type] if trgt_symbol and isinstance(trgt_symbol, Callable) and has_documentation(trgt_symbol): trgt_path = f"<{create_path(trgt_symbol)}>" - types_cache[symbol_name] = trgt_path + types_cache[(symbol_name, language)] = trgt_path return trgt_path return symbol_name @@ -318,10 +355,12 @@ def process_parts(content): base_type = part[: part.index("[")] bracket_content = part[part.index("[") :].strip("[]") processed_bracket = process_parts(bracket_content) - replacement = find_symbol(codebase=codebase, symbol_name=base_type, resolved_types=resolved_types, parent_class=parent_class, parent_symbol=parent_symbol, types_cache=types_cache) + replacement = resolve_type_symbol( + codebase=codebase, symbol_name=base_type, resolved_types=resolved_types, parent_class=parent_class, parent_symbol=parent_symbol, types_cache=types_cache + ) processed_part = replacement + "[" + processed_bracket + "]" else: - replacement = find_symbol(codebase=codebase, symbol_name=part, resolved_types=resolved_types, parent_class=parent_class, parent_symbol=parent_symbol, types_cache=types_cache) + replacement = resolve_type_symbol(codebase=codebase, symbol_name=part, resolved_types=resolved_types, parent_class=parent_class, parent_symbol=parent_symbol, types_cache=types_cache) processed_part = replacement processed_parts.append(processed_part) @@ -340,9 +379,30 @@ def process_parts(content): base_type = input_str[: input_str.index("[")] bracket_content = input_str[input_str.index("[") :].strip("[]") processed_content = process_parts(bracket_content) - replacement = find_symbol(codebase=codebase, symbol_name=base_type, resolved_types=resolved_types, parent_class=parent_class, parent_symbol=parent_symbol, types_cache=types_cache) + replacement = resolve_type_symbol(codebase=codebase, symbol_name=base_type, resolved_types=resolved_types, parent_class=parent_class, parent_symbol=parent_symbol, types_cache=types_cache) return replacement + "[" + processed_content + "]" # Handle simple input else: - replacement = find_symbol(codebase=codebase, symbol_name=input_str, resolved_types=resolved_types, parent_class=parent_class, parent_symbol=parent_symbol, types_cache=types_cache) + replacement = resolve_type_symbol(codebase=codebase, symbol_name=input_str, resolved_types=resolved_types, parent_class=parent_class, parent_symbol=parent_symbol, types_cache=types_cache) return replacement + + +def extract_class_description(docstring): + """Extract the class description from a docstring, excluding the attributes section. + + Args: + docstring (str): The class docstring to parse + + Returns: + str: The class description with whitespace normalized + """ + if not docstring: + return "" + + # Split by "Attributes:" and take only the first part + parts = docstring.split("Attributes:") + description = parts[0] + + # Normalize whitespace + lines = [line.strip() for line in description.strip().splitlines()] + return " ".join(filter(None, lines)) diff --git a/src/codegen/sdk/enums.py b/src/codegen/sdk/enums.py index 09d07344a..c2fe533af 100644 --- a/src/codegen/sdk/enums.py +++ b/src/codegen/sdk/enums.py @@ -2,6 +2,7 @@ from typing import NamedTuple from codegen.sdk.core.dataclasses.usage import Usage +from codegen.shared.decorators.docs import apidoc class NodeType(IntEnum): @@ -51,8 +52,18 @@ class SymbolType(IntEnum): Namespace = auto() +@apidoc class ImportType(IntEnum): - """Import types for each import object. Determines what the import resolves to, and what symbols are imported.""" + """Import types for each import object. Determines what the import resolves to, and what symbols are imported. + + Attributes: + DEFAULT_EXPORT: Imports all default exports. Resolves to the file. + NAMED_EXPORT: Imports a named export. Resolves to the symbol export. + WILDCARD: Imports all named exports, and default exports as `default`. Resolves to the file. + MODULE: Imports the module, not doesn't actually allow access to any of the exports + SIDE_EFFECT: Imports the module, not doesn't actually allow access to any of the exports + UNKNOWN: Unknown import type. + """ # Imports all default exports. Resolves to the file. DEFAULT_EXPORT = auto() diff --git a/src/codegen/shared/compilation/function_imports.py b/src/codegen/shared/compilation/function_imports.py index 67ce6f0b3..4b01484b0 100644 --- a/src/codegen/shared/compilation/function_imports.py +++ b/src/codegen/shared/compilation/function_imports.py @@ -108,6 +108,7 @@ def get_generated_imports(): from codegen.sdk.core.symbol_groups.multi_line_collection import MultiLineCollection from codegen.sdk.core.symbol_groups.tuple import Tuple from codegen.sdk.core.type_alias import TypeAlias +from codegen.sdk.enums import ImportType from codegen.sdk.python.assignment import PyAssignment from codegen.sdk.python.class_definition import PyClass from codegen.sdk.python.detached_symbols.code_block import PyCodeBlock