33from sqlalchemy import URL , create_engine , exc , text
44
55from patchwork .common .utils .utils import mustache_render
6+ from patchwork .logger import logger
67from patchwork .step import Step , StepStatus
78from patchwork .steps .CallSQL .typed import CallSQLInputs , CallSQLOutputs
89
910
1011class CallSQL (Step , input_class = CallSQLInputs , output_class = CallSQLOutputs ):
1112 def __init__ (self , inputs : dict ):
1213 super ().__init__ (inputs )
13- query_template_data = inputs .get ("query_template_values " , {})
14- self .query = mustache_render (inputs ["query " ], query_template_data )
14+ query_template_data = inputs .get ("db_query_template_values " , {})
15+ self .query = mustache_render (inputs ["db_query " ], query_template_data )
1516 self .__build_engine (inputs )
1617
1718 def __build_engine (self , inputs : dict ):
18- dialect = inputs ["dialect " ]
19- driver = inputs .get ("driver " )
19+ dialect = inputs ["db_dialect " ]
20+ driver = inputs .get ("db_driver " )
2021 dialect_plus_driver = f"{ dialect } +{ driver } " if driver is not None else dialect
2122 kwargs = dict (
22- username = inputs ["username " ],
23- host = inputs .get ("host " , "localhost" ),
24- port = inputs .get ("port " , 5432 ),
23+ username = inputs ["db_username " ],
24+ host = inputs .get ("db_host " , "localhost" ),
25+ port = inputs .get ("db_port " , 5432 ),
2526 )
26- if inputs .get ("password" ) is not None :
27- kwargs ["password" ] = inputs .get ("password" )
27+ if inputs .get ("db_password" ) is not None :
28+ kwargs ["password" ] = inputs .get ("db_password" )
29+ if inputs .get ("db_name" ) is not None :
30+ kwargs ["database" ] = inputs .get ("db_name" )
31+ if inputs .get ("db_params" ) is not None :
32+ kwargs ["query" ] = inputs .get ("db_params" )
2833 connection_url = URL .create (
2934 dialect_plus_driver ,
3035 ** kwargs ,
@@ -36,10 +41,14 @@ def __build_engine(self, inputs: dict):
3641
3742 def run (self ) -> dict :
3843 try :
44+ rv = []
3945 with self .engine .begin () as conn :
4046 cursor = conn .execute (text (self .query ))
41- result = cursor .fetchall ()
42- return dict (result = result )
47+ for row in cursor :
48+ result = row ._asdict ()
49+ rv .append (result )
50+ logger .info (f"Retrieved { len (rv )} rows!" )
51+ return dict (results = rv )
4352 except exc .InvalidRequestError as e :
4453 self .set_status (StepStatus .FAILED , f"`{ self .query } ` failed with message:\n { e } " )
45- return dict (result = [])
54+ return dict (results = [])
0 commit comments