33import sqlparse
44from .packages import special
55from pymysql .constants import FIELD_TYPE
6- from pymysql .converters import (convert_mysql_timestamp , convert_datetime ,
6+ from pymysql .converters import (convert_datetime ,
77 convert_timedelta , convert_date , conversions ,
88 decoders )
99try :
@@ -42,7 +42,7 @@ class SQLExecute(object):
4242
4343 def __init__ (self , database , user , password , host , port , socket , charset ,
4444 local_infile , ssl , ssh_user , ssh_host , ssh_port , ssh_password ,
45- ssh_key_filename ):
45+ ssh_key_filename , init_command = None ):
4646 self .dbname = database
4747 self .user = user
4848 self .password = password
@@ -59,12 +59,13 @@ def __init__(self, database, user, password, host, port, socket, charset,
5959 self .ssh_port = ssh_port
6060 self .ssh_password = ssh_password
6161 self .ssh_key_filename = ssh_key_filename
62+ self .init_command = init_command
6263 self .connect ()
6364
6465 def connect (self , database = None , user = None , password = None , host = None ,
6566 port = None , socket = None , charset = None , local_infile = None ,
6667 ssl = None , ssh_host = None , ssh_port = None , ssh_user = None ,
67- ssh_password = None , ssh_key_filename = None ):
68+ ssh_password = None , ssh_key_filename = None , init_command = None ):
6869 db = (database or self .dbname )
6970 user = (user or self .user )
7071 password = (password or self .password )
@@ -79,6 +80,7 @@ def connect(self, database=None, user=None, password=None, host=None,
7980 ssh_port = (ssh_port or self .ssh_port )
8081 ssh_password = (ssh_password or self .ssh_password )
8182 ssh_key_filename = (ssh_key_filename or self .ssh_key_filename )
83+ init_command = (init_command or self .init_command )
8284 _logger .debug (
8385 'Connection DB Params: \n '
8486 '\t database: %r'
@@ -93,13 +95,15 @@ def connect(self, database=None, user=None, password=None, host=None,
9395 '\t ssh_host: %r'
9496 '\t ssh_port: %r'
9597 '\t ssh_password: %r'
96- '\t ssh_key_filename: %r' ,
98+ '\t ssh_key_filename: %r'
99+ '\t init_command: %r' ,
97100 db , user , host , port , socket , charset , local_infile , ssl ,
98- ssh_user , ssh_host , ssh_port , ssh_password , ssh_key_filename
101+ ssh_user , ssh_host , ssh_port , ssh_password , ssh_key_filename ,
102+ init_command
99103 )
100104 conv = conversions .copy ()
101105 conv .update ({
102- FIELD_TYPE .TIMESTAMP : lambda obj : (convert_mysql_timestamp (obj ) or obj ),
106+ FIELD_TYPE .TIMESTAMP : lambda obj : (convert_datetime (obj ) or obj ),
103107 FIELD_TYPE .DATETIME : lambda obj : (convert_datetime (obj ) or obj ),
104108 FIELD_TYPE .TIME : lambda obj : (convert_timedelta (obj ) or obj ),
105109 FIELD_TYPE .DATE : lambda obj : (convert_date (obj ) or obj ),
@@ -110,12 +114,16 @@ def connect(self, database=None, user=None, password=None, host=None,
110114 if ssh_host :
111115 defer_connect = True
112116
117+ client_flag = pymysql .constants .CLIENT .INTERACTIVE
118+ if init_command and len (list (special .split_queries (init_command ))) > 1 :
119+ client_flag |= pymysql .constants .CLIENT .MULTI_STATEMENTS
120+
113121 conn = pymysql .connect (
114122 database = db , user = user , password = password , host = host , port = port ,
115123 unix_socket = socket , use_unicode = True , charset = charset ,
116- autocommit = True , client_flag = pymysql . constants . CLIENT . INTERACTIVE ,
124+ autocommit = True , client_flag = client_flag ,
117125 local_infile = local_infile , conv = conv , ssl = ssl , program_name = "mycli" ,
118- defer_connect = defer_connect
126+ defer_connect = defer_connect , init_command = init_command
119127 )
120128
121129 if ssh_host :
@@ -146,6 +154,7 @@ def connect(self, database=None, user=None, password=None, host=None,
146154 self .socket = socket
147155 self .charset = charset
148156 self .ssl = ssl
157+ self .init_command = init_command
149158 # retrieve connection id
150159 self .reset_connection_id ()
151160
0 commit comments