Skip to content

Commit 4082ee0

Browse files
committed
.
1 parent 914834a commit 4082ee0

File tree

1 file changed

+80
-74
lines changed
  • examples/sqlalchemy_soft_delete

1 file changed

+80
-74
lines changed
Lines changed: 80 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -1,79 +1,85 @@
11
import codegen
22
from codegen import Codebase
33
from codegen.sdk.core.detached_symbols.function_call import FunctionCall
4-
from codegen.sdk.codebase.config import CodebaseConfig, GSFeatureFlags
5-
6-
codebase = Codebase("./input_repo", config=CodebaseConfig(feature_flags=GSFeatureFlags(disable_graph=True)))
7-
8-
# Values for soft delete models and join methods
9-
soft_delete_models = {
10-
"User",
11-
"ProductWorkflow",
12-
"TransactionCanonical",
13-
"BillParametersLogEntry",
14-
"SpendEventCanonical",
15-
"TrackingCategory",
16-
"Payee",
17-
"Card",
18-
"ApprovalInstance",
19-
"Merchant",
20-
"Transaction",
21-
}
22-
join_methods = {"join", "outerjoin", "innerjoin"}
23-
24-
# Loop through all files and function calls
25-
for file in codebase.files:
26-
for call in file.function_calls:
27-
# Get the arguments as a list
28-
call_args = list(call.args)
29-
30-
# Skip if the function call is not a join method
31-
if str(call.name) not in join_methods:
32-
continue
33-
34-
# Skip if the function call has no arguments
35-
if len(call_args) == 0:
36-
continue
37-
38-
# Get the model name from the first argument
39-
model_name = str(call_args[0].value)
40-
41-
# Skip if the model name is not in the soft delete models
42-
if model_name not in soft_delete_models:
43-
continue
44-
45-
# Construct the deleted_at check expression
46-
print(f"Found join method for model {model_name} in file {file.filepath}")
47-
deleted_at_check = f"{model_name}.deleted_at.is_(None)"
48-
49-
# If there is only one argument, add the deleted_at check
50-
if len(call_args) == 1:
51-
print(f"Adding deleted_at check to function call {call.source}")
52-
call_args.append(deleted_at_check)
53-
elif len(call_args) >= 2:
54-
# Get the second argument
55-
second_arg = call_args[1].value
56-
57-
# Skip if the second argument is already the deleted_at check
58-
if second_arg.source == deleted_at_check:
59-
print(f"Skipping {file.filepath} because the deleted_at check is already present")
4+
from codegen.sdk.enums import ProgrammingLanguage
5+
6+
7+
def should_process_join_call(call, soft_delete_models, join_methods):
8+
"""Determine if a function call should be processed for soft delete conditions."""
9+
if str(call.name) not in join_methods:
10+
return False
11+
12+
call_args = list(call.args)
13+
if not call_args:
14+
return False
15+
16+
model_name = str(call_args[0].value)
17+
return model_name in soft_delete_models
18+
19+
20+
def add_deleted_at_check(file, call, model_name):
21+
"""Add the deleted_at check to a join call."""
22+
call_args = list(call.args)
23+
deleted_at_check = f"{model_name}.deleted_at.is_(None)"
24+
25+
if len(call_args) == 1:
26+
print(f"Adding deleted_at check to function call {call.source}")
27+
call_args.append(deleted_at_check)
28+
return
29+
30+
second_arg = call_args[1].value
31+
if second_arg.source == deleted_at_check:
32+
print(f"Skipping {file.filepath} because the deleted_at check is already present")
33+
return
34+
35+
if isinstance(second_arg, FunctionCall) and second_arg.name == "and_":
36+
if deleted_at_check in {str(x) for x in second_arg.args}:
37+
print(f"Skipping {file.filepath} because the deleted_at check is already present")
38+
return
39+
print(f"Adding deleted_at check to and_ call in {file.filepath}")
40+
second_arg.args.append(deleted_at_check)
41+
else:
42+
print(f"Adding deleted_at check to {file.filepath}")
43+
call_args[1].edit(f"and_({second_arg.source}, {deleted_at_check})")
44+
45+
ensure_and_import(file)
46+
47+
48+
def ensure_and_import(file):
49+
"""Ensure the file has the necessary and_ import."""
50+
if not any("and_" in imp.name for imp in file.imports):
51+
print(f"File {file.filepath} does not import and_. Adding import.")
52+
file.add_import_from_import_string("from sqlalchemy import and_")
53+
54+
55+
@codegen.function("sqlalchemy-soft-delete")
56+
def process_soft_deletes(codebase):
57+
"""Process soft delete conditions for join methods in the codebase."""
58+
soft_delete_models = {
59+
"User",
60+
"Update",
61+
"Proposal",
62+
"Comment",
63+
"Project",
64+
"Team",
65+
"SavedSession",
66+
}
67+
join_methods = {"join", "outerjoin", "innerjoin"}
68+
69+
for file in codebase.files:
70+
for call in file.function_calls:
71+
if not should_process_join_call(call, soft_delete_models, join_methods):
6072
continue
6173

62-
# If the second argument is an and_ call, add the deleted_at check if it's not already present
63-
if isinstance(second_arg, FunctionCall) and second_arg.name == "and_":
64-
if deleted_at_check in {str(x) for x in second_arg.args}:
65-
print(f"Skipping {file.filepath} because the deleted_at check is already present")
66-
continue
67-
else:
68-
print(f"Adding deleted_at check to and_ call in {file.filepath}")
69-
second_arg.args.append(deleted_at_check)
70-
else:
71-
print(f"Adding deleted_at check to {file.filepath}")
72-
call_args[1].edit(f"and_({second_arg.source}, {deleted_at_check})")
73-
74-
# Check if the file imports and_
75-
if any("and_" in imp.name for imp in file.imports):
76-
print(f"File {file.filepath} imports and_")
77-
else:
78-
print(f"File {file.filepath} does not import and_. Adding import.")
79-
file.add_import_from_import_string("from sqlalchemy import and_")
74+
model_name = str(list(call.args)[0].value)
75+
print(f"Found join method for model {model_name} in file {file.filepath}")
76+
add_deleted_at_check(file, call, model_name)
77+
78+
print("commit")
79+
print(codebase.get_diff())
80+
81+
82+
if __name__ == "__main__":
83+
codebase = Codebase.from_repo("hasgeek/funnel", programming_language=ProgrammingLanguage.PYTHON)
84+
print(codebase.files)
85+
process_soft_deletes(codebase)

0 commit comments

Comments
 (0)