Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion src/codegen/sdk/codebase/node_classes/node_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from codegen.sdk.core.file import SourceFile
from codegen.sdk.core.function import Function
from codegen.sdk.core.import_resolution import Import
from codegen.sdk.core.interfaces.editable import Editable
from codegen.sdk.core.statements.comment import Comment
from codegen.sdk.core.symbol import Symbol

Expand All @@ -33,7 +34,7 @@ class NodeClasses:
function_call_cls: type[FunctionCall]
comment_cls: type[Comment]
bool_conversion: dict[bool, str]
dynamic_import_parent_types: set[str]
dynamic_import_parent_types: set[type[Editable]]
symbol_map: dict[str, type[Symbol]] = field(default_factory=dict)
expression_map: dict[str, type[Expression]] = field(default_factory=dict)
type_map: dict[str, type[Type] | dict[str, type[Type]]] = field(default_factory=dict)
Expand Down
27 changes: 16 additions & 11 deletions src/codegen/sdk/codebase/node_classes/py_node_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,13 @@
from codegen.sdk.core.expressions.subscript_expression import SubscriptExpression
from codegen.sdk.core.expressions.unary_expression import UnaryExpression
from codegen.sdk.core.expressions.unpack import Unpack
from codegen.sdk.core.function import Function
from codegen.sdk.core.statements.comment import Comment
from codegen.sdk.core.statements.for_loop_statement import ForLoopStatement
from codegen.sdk.core.statements.if_block_statement import IfBlockStatement
from codegen.sdk.core.statements.switch_statement import SwitchStatement
from codegen.sdk.core.statements.try_catch_statement import TryCatchStatement
from codegen.sdk.core.statements.while_statement import WhileStatement
from codegen.sdk.core.symbol_groups.dict import Dict
from codegen.sdk.core.symbol_groups.list import List
from codegen.sdk.core.symbol_groups.tuple import Tuple
Expand All @@ -29,6 +35,8 @@
from codegen.sdk.python.expressions.string import PyString
from codegen.sdk.python.expressions.union_type import PyUnionType
from codegen.sdk.python.statements.import_statement import PyImportStatement
from codegen.sdk.python.statements.match_case import PyMatchCase
from codegen.sdk.python.statements.with_statement import WithStatement


def parse_subscript(node: TSNode, file_node_id, ctx, parent):
Expand Down Expand Up @@ -110,16 +118,13 @@ def parse_subscript(node: TSNode, file_node_id, ctx, parent):
False: "False",
},
dynamic_import_parent_types={
"function_definition",
"if_statement",
"try_statement",
"with_statement",
"else_clause",
"for_statement",
"except_clause",
"while_statement",
"match_statement",
"case_clause",
"finally_clause",
Function,
IfBlockStatement,
TryCatchStatement,
WithStatement,
ForLoopStatement,
WhileStatement,
SwitchStatement,
PyMatchCase,
},
)
27 changes: 14 additions & 13 deletions src/codegen/sdk/codebase/node_classes/ts_node_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,14 @@
from codegen.sdk.core.expressions.unary_expression import UnaryExpression
from codegen.sdk.core.expressions.unpack import Unpack
from codegen.sdk.core.expressions.value import Value
from codegen.sdk.core.function import Function
from codegen.sdk.core.statements.comment import Comment
from codegen.sdk.core.statements.for_loop_statement import ForLoopStatement
from codegen.sdk.core.statements.if_block_statement import IfBlockStatement
from codegen.sdk.core.statements.switch_case import SwitchCase
from codegen.sdk.core.statements.switch_statement import SwitchStatement
from codegen.sdk.core.statements.try_catch_statement import TryCatchStatement
from codegen.sdk.core.statements.while_statement import WhileStatement
from codegen.sdk.core.symbol_groups.list import List
from codegen.sdk.core.symbol_groups.type_parameters import TypeParameters
from codegen.sdk.typescript.class_definition import TSClass
Expand Down Expand Up @@ -166,18 +173,12 @@ def parse_new(node: TSNode, *args):
False: "false",
},
dynamic_import_parent_types={
"function_declaration",
"method_definition",
"arrow_function",
"if_statement",
"try_statement",
"else_clause",
"catch_clause",
"finally_clause",
"while_statement",
"for_statement",
"do_statement",
"switch_case",
"switch_statement",
Function,
IfBlockStatement,
TryCatchStatement,
ForLoopStatement,
WhileStatement,
SwitchStatement,
SwitchCase,
},
)
108 changes: 108 additions & 0 deletions src/codegen/sdk/core/codebase.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import logging
import os
import re
import tempfile
from collections.abc import Generator
from contextlib import contextmanager
from functools import cached_property
Expand Down Expand Up @@ -1298,6 +1299,113 @@ def from_repo(
logger.exception(f"Failed to initialize codebase: {e}")
raise

@classmethod
def from_string(
cls,
code: str,
*,
language: Literal["python", "typescript"] | ProgrammingLanguage,
) -> "Codebase":
"""Creates a Codebase instance from a string of code.

Args:
code (str): The source code string
language (Literal["python", "typescript"] | ProgrammingLanguage): The programming language of the code.

Returns:
Codebase: A Codebase instance initialized with the provided code
"""
logger.info("Creating codebase from string")

# Determine language and filename
prog_lang = ProgrammingLanguage(language.upper()) if isinstance(language, str) else language
filename = "test.ts" if prog_lang == ProgrammingLanguage.TYPESCRIPT else "test.py"

with tempfile.TemporaryDirectory(prefix="codegen_") as tmp_dir:
logger.info(f"Using directory: {tmp_dir}")

# Create codebase using factory
from codegen.sdk.codebase.factory.codebase_factory import CodebaseFactory

files = {filename: code}
codebase = CodebaseFactory.get_codebase_from_files(repo_path=tmp_dir, files=files, programming_language=prog_lang)
logger.info("Codebase initialization complete")
return codebase

@classmethod
def from_files(
cls,
files: dict[str, str],
*,
language: Literal["python", "typescript"] | ProgrammingLanguage | None = None,
) -> "Codebase":
"""Creates a Codebase instance from multiple files.

Args:
files: Dictionary mapping filenames to their content, e.g. {"main.py": "print('hello')"}
language: Optional language override. If not provided, will be inferred from file extensions.
All files must have extensions matching the same language.

Returns:
Codebase: A Codebase instance initialized with the provided files

Raises:
ValueError: If file extensions don't match a single language or if explicitly provided
language doesn't match the extensions

Example:
>>> # Language inferred as Python
>>> files = {"main.py": "print('hello')", "utils.py": "def add(a, b): return a + b"}
>>> codebase = Codebase.from_files(files)

>>> # Language inferred as TypeScript
>>> files = {"index.ts": "console.log('hello')", "utils.tsx": "export const App = () => <div>Hello</div>"}
>>> codebase = Codebase.from_files(files)
"""
logger.info("Creating codebase from files")

if not files:
# Default to Python if no files provided
prog_lang = ProgrammingLanguage.PYTHON if language is None else (ProgrammingLanguage(language.upper()) if isinstance(language, str) else language)
logger.info(f"No files provided, using {prog_lang}")
else:
# Map extensions to languages
py_extensions = {".py"}
ts_extensions = {".ts", ".tsx", ".js", ".jsx"}

# Get unique extensions from files
extensions = {os.path.splitext(f)[1].lower() for f in files}

# Determine language from extensions
inferred_lang = None
if all(ext in py_extensions for ext in extensions):
inferred_lang = ProgrammingLanguage.PYTHON
elif all(ext in ts_extensions for ext in extensions):
inferred_lang = ProgrammingLanguage.TYPESCRIPT
else:
msg = f"Cannot determine single language from extensions: {extensions}. Files must all be Python (.py) or TypeScript (.ts, .tsx, .js, .jsx)"
raise ValueError(msg)

# If language was explicitly provided, verify it matches inferred language
if language is not None:
explicit_lang = ProgrammingLanguage(language.upper()) if isinstance(language, str) else language
if explicit_lang != inferred_lang:
msg = f"Provided language {explicit_lang} doesn't match inferred language {inferred_lang} from file extensions"
raise ValueError(msg)

prog_lang = inferred_lang
logger.info(f"Using language: {prog_lang} ({'inferred' if language is None else 'explicit'})")

with tempfile.TemporaryDirectory(prefix="codegen_") as tmp_dir:
logger.info(f"Using directory: {tmp_dir}")

# Create codebase using factory
from codegen.sdk.codebase.factory.codebase_factory import CodebaseFactory

codebase = CodebaseFactory.get_codebase_from_files(repo_path=tmp_dir, files=files, programming_language=prog_lang)
logger.info("Codebase initialization complete")
return codebase

def get_modified_symbols_in_pr(self, pr_id: int) -> tuple[str, dict[str, str], list[str]]:
"""Get all modified symbols in a pull request"""
pr = self._op.get_pull_request(pr_id)
Expand Down
5 changes: 0 additions & 5 deletions src/codegen/sdk/core/file.py
Original file line number Diff line number Diff line change
Expand Up @@ -459,15 +459,10 @@ def parse(self, ctx: CodebaseContext) -> None:
self.code_block = self._parse_code_block(self.ts_node)

self.code_block.parse()
self._parse_imports()
# We need to clear the valid symbol/import names before we start resolving exports since these can be outdated.
self.invalidate()
sort_editables(self._nodes)

@abstractmethod
@commiter
def _parse_imports(self) -> None: ...

@noapidoc
@commiter
def remove_internal_edges(self) -> None:
Expand Down
10 changes: 1 addition & 9 deletions src/codegen/sdk/core/import_resolution.py
Original file line number Diff line number Diff line change
Expand Up @@ -428,15 +428,7 @@ def my_function():
bool: True if the import is dynamic (within a control flow or scope block),
False if it's a top-level import.
"""
curr = self.ts_node

# always traverses upto the module level
while curr:
if curr.type in self.ctx.node_classes.dynamic_import_parent_types:
return True
curr = curr.parent

return False
return self.parent_of_types(self.ctx.node_classes.dynamic_import_parent_types) is not None

####################################################################################################################
# MANIPULATIONS
Expand Down
10 changes: 9 additions & 1 deletion src/codegen/sdk/core/interfaces/editable.py
Original file line number Diff line number Diff line change
Expand Up @@ -823,7 +823,7 @@ def children_by_field_types(self, field_types: str | Iterable[str]) -> Generator
@reader
@noapidoc
def child_by_field_types(self, field_types: str | Iterable[str]) -> Expression[Self] | None:
"""Get child by field types."""
"""Get child by fiexld types."""
return next(self.children_by_field_types(field_types), None)

@property
Expand Down Expand Up @@ -1097,6 +1097,14 @@ def parent_of_type(self, type: type[T]) -> T | None:
return self.parent.parent_of_type(type)
return None

def parent_of_types(self, types: set[type[T]]) -> T | None:
"""Find the first ancestor of the node of the given type. Does not return itself"""
if self.parent and any(isinstance(self.parent, t) for t in types):
return self.parent
if self.parent is not self and self.parent is not None:
return self.parent.parent_of_types(types)
return None

@reader
def ancestors(self, type: type[T]) -> list[T]:
"""Find all ancestors of the node of the given type. Does not return itself"""
Expand Down
Loading
Loading