@@ -812,7 +812,8 @@ def psql(self,
812812 filename = None ,
813813 dbname = None ,
814814 username = None ,
815- input = None ):
815+ input = None ,
816+ ** variables ):
816817 """
817818 Execute a query using psql.
818819
@@ -822,9 +823,18 @@ def psql(self,
822823 dbname: database name to connect to.
823824 username: database user name.
824825 input: raw input to be passed.
826+ **variables: vars to be set before execution.
825827
826828 Returns:
827829 A tuple of (code, stdout, stderr).
830+
831+ Examples:
832+ >>> psql('select 1')
833+ (0, b'1\n ', b'')
834+ >>> psql('postgres', 'select 2')
835+ (0, b'2\n ', b'')
836+ >>> psql(query='select 3', ON_ERROR_STOP=1)
837+ (0, b'3\n ', b'')
828838 """
829839
830840 # Set default arguments
@@ -843,6 +853,10 @@ def psql(self,
843853 dbname
844854 ] # yapf: disable
845855
856+ # set variables before execution
857+ for key , value in iteritems (variables ):
858+ psql_params .extend (["--set" , '{}={}' .format (key , value )])
859+
846860 # select query source
847861 if query :
848862 psql_params .extend (("-c" , query ))
@@ -874,10 +888,15 @@ def safe_psql(self, query=None, **kwargs):
874888 username: database user name.
875889 input: raw input to be passed.
876890
891+ **kwargs are passed to psql().
892+
877893 Returns:
878894 psql's output as str.
879895 """
880896
897+ # force this setting
898+ kwargs ['ON_ERROR_STOP' ] = 1
899+
881900 ret , out , err = self .psql (query = query , ** kwargs )
882901 if ret :
883902 raise QueryException ((err or b'' ).decode ('utf-8' ), query )
0 commit comments