|
1 | 1 | import codegen |
2 | 2 | from codegen import Codebase |
3 | 3 | from codegen.sdk.core.detached_symbols.function_call import FunctionCall |
4 | | -from codegen.sdk.codebase.config import CodebaseConfig, GSFeatureFlags |
5 | | - |
6 | | -codebase = Codebase("./input_repo", config=CodebaseConfig(feature_flags=GSFeatureFlags(disable_graph=True))) |
7 | | - |
8 | | -# Values for soft delete models and join methods |
9 | | -soft_delete_models = { |
10 | | - "User", |
11 | | - "ProductWorkflow", |
12 | | - "TransactionCanonical", |
13 | | - "BillParametersLogEntry", |
14 | | - "SpendEventCanonical", |
15 | | - "TrackingCategory", |
16 | | - "Payee", |
17 | | - "Card", |
18 | | - "ApprovalInstance", |
19 | | - "Merchant", |
20 | | - "Transaction", |
21 | | -} |
22 | | -join_methods = {"join", "outerjoin", "innerjoin"} |
23 | | - |
24 | | -# Loop through all files and function calls |
25 | | -for file in codebase.files: |
26 | | - for call in file.function_calls: |
27 | | - # Get the arguments as a list |
28 | | - call_args = list(call.args) |
29 | | - |
30 | | - # Skip if the function call is not a join method |
31 | | - if str(call.name) not in join_methods: |
32 | | - continue |
33 | | - |
34 | | - # Skip if the function call has no arguments |
35 | | - if len(call_args) == 0: |
36 | | - continue |
37 | | - |
38 | | - # Get the model name from the first argument |
39 | | - model_name = str(call_args[0].value) |
40 | | - |
41 | | - # Skip if the model name is not in the soft delete models |
42 | | - if model_name not in soft_delete_models: |
43 | | - continue |
44 | | - |
45 | | - # Construct the deleted_at check expression |
46 | | - print(f"Found join method for model {model_name} in file {file.filepath}") |
47 | | - deleted_at_check = f"{model_name}.deleted_at.is_(None)" |
48 | | - |
49 | | - # If there is only one argument, add the deleted_at check |
50 | | - if len(call_args) == 1: |
51 | | - print(f"Adding deleted_at check to function call {call.source}") |
52 | | - call_args.append(deleted_at_check) |
53 | | - elif len(call_args) >= 2: |
54 | | - # Get the second argument |
55 | | - second_arg = call_args[1].value |
56 | | - |
57 | | - # Skip if the second argument is already the deleted_at check |
58 | | - if second_arg.source == deleted_at_check: |
59 | | - print(f"Skipping {file.filepath} because the deleted_at check is already present") |
| 4 | +from codegen.sdk.enums import ProgrammingLanguage |
| 5 | + |
| 6 | + |
| 7 | +def should_process_join_call(call, soft_delete_models, join_methods): |
| 8 | + """Determine if a function call should be processed for soft delete conditions.""" |
| 9 | + if str(call.name) not in join_methods: |
| 10 | + return False |
| 11 | + |
| 12 | + call_args = list(call.args) |
| 13 | + if not call_args: |
| 14 | + return False |
| 15 | + |
| 16 | + model_name = str(call_args[0].value) |
| 17 | + return model_name in soft_delete_models |
| 18 | + |
| 19 | + |
| 20 | +def add_deleted_at_check(file, call, model_name): |
| 21 | + """Add the deleted_at check to a join call.""" |
| 22 | + call_args = list(call.args) |
| 23 | + deleted_at_check = f"{model_name}.deleted_at.is_(None)" |
| 24 | + |
| 25 | + if len(call_args) == 1: |
| 26 | + print(f"Adding deleted_at check to function call {call.source}") |
| 27 | + call_args.append(deleted_at_check) |
| 28 | + return |
| 29 | + |
| 30 | + second_arg = call_args[1].value |
| 31 | + if second_arg.source == deleted_at_check: |
| 32 | + print(f"Skipping {file.filepath} because the deleted_at check is already present") |
| 33 | + return |
| 34 | + |
| 35 | + if isinstance(second_arg, FunctionCall) and second_arg.name == "and_": |
| 36 | + if deleted_at_check in {str(x) for x in second_arg.args}: |
| 37 | + print(f"Skipping {file.filepath} because the deleted_at check is already present") |
| 38 | + return |
| 39 | + print(f"Adding deleted_at check to and_ call in {file.filepath}") |
| 40 | + second_arg.args.append(deleted_at_check) |
| 41 | + else: |
| 42 | + print(f"Adding deleted_at check to {file.filepath}") |
| 43 | + call_args[1].edit(f"and_({second_arg.source}, {deleted_at_check})") |
| 44 | + |
| 45 | + ensure_and_import(file) |
| 46 | + |
| 47 | + |
| 48 | +def ensure_and_import(file): |
| 49 | + """Ensure the file has the necessary and_ import.""" |
| 50 | + if not any("and_" in imp.name for imp in file.imports): |
| 51 | + print(f"File {file.filepath} does not import and_. Adding import.") |
| 52 | + file.add_import_from_import_string("from sqlalchemy import and_") |
| 53 | + |
| 54 | + |
| 55 | +@codegen.function("sqlalchemy-soft-delete") |
| 56 | +def process_soft_deletes(codebase): |
| 57 | + """Process soft delete conditions for join methods in the codebase.""" |
| 58 | + soft_delete_models = { |
| 59 | + "User", |
| 60 | + "Update", |
| 61 | + "Proposal", |
| 62 | + "Comment", |
| 63 | + "Project", |
| 64 | + "Team", |
| 65 | + "SavedSession", |
| 66 | + } |
| 67 | + join_methods = {"join", "outerjoin", "innerjoin"} |
| 68 | + |
| 69 | + for file in codebase.files: |
| 70 | + for call in file.function_calls: |
| 71 | + if not should_process_join_call(call, soft_delete_models, join_methods): |
60 | 72 | continue |
61 | 73 |
|
62 | | - # If the second argument is an and_ call, add the deleted_at check if it's not already present |
63 | | - if isinstance(second_arg, FunctionCall) and second_arg.name == "and_": |
64 | | - if deleted_at_check in {str(x) for x in second_arg.args}: |
65 | | - print(f"Skipping {file.filepath} because the deleted_at check is already present") |
66 | | - continue |
67 | | - else: |
68 | | - print(f"Adding deleted_at check to and_ call in {file.filepath}") |
69 | | - second_arg.args.append(deleted_at_check) |
70 | | - else: |
71 | | - print(f"Adding deleted_at check to {file.filepath}") |
72 | | - call_args[1].edit(f"and_({second_arg.source}, {deleted_at_check})") |
73 | | - |
74 | | - # Check if the file imports and_ |
75 | | - if any("and_" in imp.name for imp in file.imports): |
76 | | - print(f"File {file.filepath} imports and_") |
77 | | - else: |
78 | | - print(f"File {file.filepath} does not import and_. Adding import.") |
79 | | - file.add_import_from_import_string("from sqlalchemy import and_") |
| 74 | + model_name = str(list(call.args)[0].value) |
| 75 | + print(f"Found join method for model {model_name} in file {file.filepath}") |
| 76 | + add_deleted_at_check(file, call, model_name) |
| 77 | + |
| 78 | + print("commit") |
| 79 | + print(codebase.get_diff()) |
| 80 | + |
| 81 | + |
| 82 | +if __name__ == "__main__": |
| 83 | + codebase = Codebase.from_repo("hasgeek/funnel", programming_language=ProgrammingLanguage.PYTHON) |
| 84 | + print(codebase.files) |
| 85 | + process_soft_deletes(codebase) |
0 commit comments