Skip to content

Commit 0d2fc93

Browse files
authored
Enhance SQLAlchemy Type Annotations Example (#12)
* . * readme * done * new format * .
1 parent cde9cd5 commit 0d2fc93

File tree

9 files changed

+385
-0
lines changed

9 files changed

+385
-0
lines changed
Lines changed: 147 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,147 @@
1+
# Enhance SQLAlchemy Type Annotations
2+
3+
This codemod demonstrates how to automatically add type annotations to SQLAlchemy models in your Python codebase. The migration script makes this process simple by handling all the tedious manual updates automatically.
4+
5+
## How the Migration Script Works
6+
7+
The script automates the entire migration process in a few key steps:
8+
9+
1. **Model Detection and Analysis**
10+
```python
11+
codebase = Codebase.from_repo("your/repo")
12+
for file in codebase.files:
13+
if "models" not in file.filepath:
14+
continue
15+
```
16+
- Automatically identifies SQLAlchemy model files
17+
- Analyzes model structure and relationships
18+
- Determines required type annotations
19+
20+
2. **Type Annotation Updates**
21+
```python
22+
for column in model.columns:
23+
if isinstance(column, Column):
24+
column.edit(to_mapped_column(column))
25+
```
26+
- Converts Column definitions to typed Mapped columns
27+
- Handles nullable fields with Optional types
28+
- Preserves existing column configurations
29+
30+
3. **Relationship Transformations**
31+
```python
32+
for rel in model.relationships:
33+
if isinstance(rel, relationship):
34+
rel.edit(to_typed_relationship(rel))
35+
```
36+
- Updates relationship definitions with proper typing
37+
- Converts backref to back_populates
38+
- Adds List/Optional type wrappers as needed
39+
40+
## Common Migration Patterns
41+
42+
### Column Definitions
43+
```python
44+
# Before
45+
id = Column(Integer, primary_key=True)
46+
name = Column(String)
47+
48+
# After
49+
id: Mapped[int] = mapped_column(primary_key=True)
50+
name: Mapped[str] = mapped_column()
51+
```
52+
53+
### Nullable Fields
54+
```python
55+
# Before
56+
description = Column(String, nullable=True)
57+
58+
# After
59+
description: Mapped[Optional[str]] = mapped_column(nullable=True)
60+
```
61+
62+
### Relationships
63+
```python
64+
# Before
65+
addresses = relationship("Address", backref="user")
66+
67+
# After
68+
addresses: Mapped[List["Address"]] = relationship(back_populates="user")
69+
```
70+
71+
## Complete Example
72+
73+
### Before Migration
74+
```python
75+
from sqlalchemy import Column, Integer, String, ForeignKey
76+
from sqlalchemy.orm import relationship, backref
77+
from database import Base
78+
79+
class Publisher(Base):
80+
__tablename__ = "publishers"
81+
82+
id = Column(Integer, primary_key=True, index=True)
83+
name = Column(String, unique=True, index=True)
84+
books = relationship("Book", backref="publisher")
85+
86+
class Book(Base):
87+
__tablename__ = "books"
88+
89+
id = Column(Integer, primary_key=True, index=True)
90+
title = Column(String, index=True)
91+
author = Column(String, index=True)
92+
description = Column(String)
93+
publisher_id = Column(Integer, ForeignKey("publishers.id"))
94+
```
95+
96+
### After Migration
97+
```python
98+
from typing import List, Optional
99+
from sqlalchemy import ForeignKey
100+
from sqlalchemy.orm import Mapped, mapped_column, relationship
101+
from database import Base
102+
103+
class Publisher(Base):
104+
__tablename__ = "publishers"
105+
106+
id: Mapped[int] = mapped_column(primary_key=True, index=True)
107+
name: Mapped[str] = mapped_column(unique=True, index=True)
108+
books: Mapped[List["Book"]] = relationship(
109+
"Book",
110+
back_populates="publisher"
111+
)
112+
113+
class Book(Base):
114+
__tablename__ = "books"
115+
116+
id: Mapped[int] = mapped_column(primary_key=True, index=True)
117+
title: Mapped[str] = mapped_column(index=True)
118+
author: Mapped[str] = mapped_column(index=True)
119+
description: Mapped[Optional[str]] = mapped_column(nullable=True)
120+
publisher_id: Mapped[Optional[int]] = mapped_column(
121+
ForeignKey("publishers.id"),
122+
nullable=True
123+
)
124+
publisher: Mapped[Optional["Publisher"]] = relationship(
125+
"Publisher",
126+
back_populates="books"
127+
)
128+
```
129+
130+
## Running the Migration
131+
132+
```bash
133+
# Install Codegen
134+
pip install codegen
135+
# Run the migration
136+
python run.py
137+
```
138+
139+
## Learn More
140+
141+
- [SQLAlchemy 2.0 Documentation](https://docs.sqlalchemy.org/en/20/)
142+
- [SQLAlchemy Type Annotations Guide](https://docs.sqlalchemy.org/en/20/orm/typing.html)
143+
- [Codegen Documentation](https://docs.codegen.com)
144+
145+
## Contributing
146+
147+
Feel free to submit issues and enhancement requests!
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
# SQLAlchemy Type Notations Example
2+
3+
A minimal repository for testing SQLAlchemy type annotations and database patterns.
4+
5+
## Purpose
6+
7+
- Test SQLAlchemy type annotations
8+
- Experiment with database patterns
9+
- Quick prototyping environment
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
import os
2+
3+
DATABASE_URL = os.getenv("DATABASE_URL", "postgresql://user:pass@localhost:5432/db")
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
from sqlalchemy import create_engine
2+
from sqlalchemy.orm import sessionmaker
3+
from ..config.settings import DATABASE_URL
4+
5+
engine = create_engine(DATABASE_URL)
6+
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
from sqlalchemy.ext.declarative import declarative_base
2+
from sqlalchemy.orm import Session
3+
4+
Base = declarative_base()
5+
6+
7+
def get_db() -> Session:
8+
# Placeholder for DB session creation
9+
pass
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
from sqlalchemy.orm import Mapped
2+
3+
from datetime import datetime
4+
5+
from sqlalchemy import Column, Integer, String, DateTime
6+
from sqlalchemy.orm import relationship
7+
from .base import Base
8+
9+
10+
class Organization(Base):
11+
__tablename__ = "organizations"
12+
13+
id = Column(Integer, primary_key=True)
14+
name = Column(String(200))
15+
xero_organization_id = Column(String(50), unique=True)
16+
stripe_customer_id = Column(String(100))
17+
updated_at = Column(DateTime)
18+
19+
# Relationships
20+
users = relationship("User", back_populates="organization")
21+
transactions = relationship("Transaction", back_populates="organization")
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
from sqlalchemy.orm import Mapped
2+
3+
from decimal import Decimal
4+
5+
from datetime import datetime
6+
7+
from sqlalchemy import Column, Integer, String, ForeignKey, Numeric, DateTime
8+
from sqlalchemy.orm import relationship
9+
from .base import Base
10+
11+
12+
class Transaction(Base):
13+
__tablename__ = "transactions"
14+
15+
id = Column(Integer, primary_key=True)
16+
amount = Column(Numeric(10, 2))
17+
description = Column(String(500))
18+
reference_id = Column(String(100))
19+
user_id = Column(Integer, ForeignKey("users.id"))
20+
organization_id = Column(Integer, ForeignKey("organizations.id"))
21+
created_at = Column(DateTime)
22+
23+
# Relationships
24+
user = relationship("User", back_populates="transactions")
25+
organization = relationship("Organization", back_populates="transactions")
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
from sqlalchemy.orm import Mapped
2+
3+
from sqlalchemy import Column, Integer, String, ForeignKey, Boolean
4+
from sqlalchemy.orm import relationship
5+
from .base import Base
6+
7+
8+
class User(Base):
9+
__tablename__ = "users"
10+
11+
id = Column(Integer, primary_key=True)
12+
email = Column(String(255), unique=True)
13+
username = Column(String(100))
14+
is_active = Column(Boolean, default=True)
15+
organization_id = Column(Integer, ForeignKey("organizations.id"))
16+
17+
# Relationships
18+
organization = relationship("Organization", back_populates="users")
19+
transactions = relationship("Transaction", back_populates="user")

sqlalchemy_type_annotations/run.py

Lines changed: 146 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,146 @@
1+
import codegen
2+
from codegen import Codebase
3+
from codegen.sdk.core.detached_symbols.function_call import FunctionCall
4+
import subprocess
5+
import shutil
6+
import os
7+
8+
9+
def init_git_repo(repo_path: str) -> None:
10+
"""Initialize a git repository in the given path."""
11+
subprocess.run(["git", "init"], cwd=repo_path, check=True)
12+
subprocess.run(["git", "add", "."], cwd=repo_path, check=True)
13+
subprocess.run(["git", "commit", "-m", "Initial commit"], cwd=repo_path, check=True)
14+
15+
16+
def cleanup_git_repo(repo_path: str) -> None:
17+
"""Remove the .git directory from the given path."""
18+
git_dir = os.path.join(repo_path, ".git")
19+
if os.path.exists(git_dir):
20+
shutil.rmtree(git_dir)
21+
22+
23+
@codegen.function("sqlalchemy-type-annotations")
24+
def run(codebase: Codebase):
25+
"""Add Mapped types to SQLAlchemy models in a codebase.
26+
27+
This codemod:
28+
1. Finds all SQLAlchemy model classes
29+
2. Converts Column type annotations to Mapped types
30+
3. Adds necessary imports for the new type annotations
31+
"""
32+
# Define type mapping
33+
column_type_to_mapped_type = {
34+
"Integer": "Mapped[int]",
35+
"Optional[Integer]": "Mapped[int | None]",
36+
"Boolean": "Mapped[bool]",
37+
"Optional[Boolean]": "Mapped[bool | None]",
38+
"DateTime": "Mapped[datetime | None]",
39+
"Optional[DateTime]": "Mapped[datetime | None]",
40+
"String": "Mapped[str]",
41+
"Optional[String]": "Mapped[str | None]",
42+
"Numeric": "Mapped[Decimal]",
43+
"Optional[Numeric]": "Mapped[Decimal | None]",
44+
}
45+
46+
# Track statistics
47+
classes_modified = 0
48+
attributes_modified = 0
49+
50+
# Traverse the codebase classes
51+
for cls in codebase.classes:
52+
class_modified = False
53+
original_source = cls.source # Store original source before modifications
54+
55+
for attribute in cls.attributes:
56+
if not attribute.assignment:
57+
continue
58+
59+
assignment_value = attribute.assignment.value
60+
if not isinstance(assignment_value, FunctionCall):
61+
continue
62+
63+
if assignment_value.name != "Column":
64+
continue
65+
66+
db_column_call = assignment_value
67+
68+
# Make sure we have at least one argument (the type)
69+
if len(db_column_call.args) == 0:
70+
continue
71+
72+
# Check for nullable=True
73+
is_nullable = any(
74+
x.name == "nullable" and x.value == "True" for x in db_column_call.args
75+
)
76+
77+
# Extract the first argument for the column type
78+
first_argument = db_column_call.args[0].source or ""
79+
first_argument = first_argument.split("(")[0].strip()
80+
81+
# If the type is namespaced (e.g. sa.Integer), get the last part
82+
if "." in first_argument:
83+
first_argument = first_argument.split(".")[-1]
84+
85+
# If nullable, wrap the type in Optional[...]
86+
if is_nullable:
87+
first_argument = f"Optional[{first_argument}]"
88+
89+
# Check if we have a corresponding mapped type
90+
if first_argument not in column_type_to_mapped_type:
91+
print(f"Skipping unmapped type: {first_argument}")
92+
continue
93+
94+
# Build the new mapped type annotation
95+
new_type = column_type_to_mapped_type[first_argument]
96+
97+
# Update the assignment type annotation
98+
attribute.assignment.set_type_annotation(new_type)
99+
attributes_modified += 1
100+
class_modified = True
101+
102+
# Add necessary imports
103+
if not cls.file.has_import("Mapped"):
104+
cls.file.add_import_from_import_string(
105+
"from sqlalchemy.orm import Mapped\n"
106+
)
107+
108+
if "Optional" in new_type and not cls.file.has_import("Optional"):
109+
cls.file.add_import_from_import_string("from typing import Optional\n")
110+
111+
if "Decimal" in new_type and not cls.file.has_import("Decimal"):
112+
cls.file.add_import_from_import_string("from decimal import Decimal\n")
113+
114+
if "datetime" in new_type and not cls.file.has_import("datetime"):
115+
cls.file.add_import_from_import_string(
116+
"from datetime import datetime\n"
117+
)
118+
119+
if class_modified:
120+
classes_modified += 1
121+
# Print the diff for this class
122+
print(f"\nModified class: {cls.name}")
123+
print("Before:")
124+
print(original_source)
125+
print("\nAfter:")
126+
print(cls.source)
127+
print("-" * 80)
128+
129+
print("\nModification complete:")
130+
print(f"Classes modified: {classes_modified}")
131+
print(f"Attributes modified: {attributes_modified}")
132+
133+
134+
if __name__ == "__main__":
135+
input_repo = "./input_repo"
136+
print("Initializing git repository...")
137+
init_git_repo(input_repo)
138+
139+
print("Initializing codebase...")
140+
codebase = Codebase(input_repo)
141+
142+
print("Running codemod...")
143+
run(codebase)
144+
145+
print("Cleaning up git repository...")
146+
cleanup_git_repo(input_repo)

0 commit comments

Comments
 (0)