diff --git a/aiomysql/connection.py b/aiomysql/connection.py index 8bf5cbf4..7f5085b4 100644 --- a/aiomysql/connection.py +++ b/aiomysql/connection.py @@ -15,7 +15,6 @@ from pymysql.constants import SERVER_STATUS from pymysql.constants import CLIENT from pymysql.constants import COMMAND -from pymysql.constants import FIELD_TYPE from pymysql.util import byte2int, int2byte from pymysql.converters import (escape_item, encoders, decoders, escape_string, escape_bytes_prefixed, through) @@ -25,7 +24,7 @@ IntegrityError, InternalError, NotSupportedError, ProgrammingError) -from pymysql.connections import TEXT_TYPES, MAX_PACKET_LEN, DEFAULT_CHARSET +from pymysql.connections import MAX_PACKET_LEN, DEFAULT_CHARSET from pymysql.connections import _auth from pymysql.connections import pack_int24 @@ -39,8 +38,9 @@ # from aiomysql.utils import _convert_to_str from .cursors import Cursor -from .utils import _ConnectionContextManager, _ContextManager +from .utils import _ConnectionContextManager, _ContextManager, _decide_encoding from .log import logger +from .prepared_statement import PreparedStatement DEFAULT_USER = getpass.getuser() @@ -971,6 +971,38 @@ async def sha256_password_auth(self, pkt): pkt.check_error() return pkt + async def prepare(self, sql): + await self._execute_command(COMMAND.COM_STMT_PREPARE, sql) + packet = await self._read_packet() + if not packet.is_ok_packet(): + raise Error("Unexpected error") + + # status + packet.advance(1) + statement_id = packet.read_uint32() + num_columns = packet.read_uint16() + num_params = packet.read_uint16() + # reserved + packet.advance(1) + # warning count + packet.read_uint16() + params = [] + columns = [] + if num_params > 0: + for _ in range(num_params): + params.append(await self._read_packet(FieldDescriptorPacket)) + if self.client_flag | CLIENT.PROTOCOL_41: + # EOF + await self._read_packet() + if num_columns > 0: + for _ in range(num_columns): + columns.append(await self._read_packet(FieldDescriptorPacket)) + if self.client_flag | CLIENT.PROTOCOL_41: + # EOF + await self._read_packet() + + return PreparedStatement(self, statement_id, params, columns) + # _mysql support def thread_id(self): return self.server_thread_id[0] @@ -1244,30 +1276,8 @@ async def _get_descriptions(self): FieldDescriptorPacket) self.fields.append(field) description.append(field.description()) - field_type = field.type_code - if use_unicode: - if field_type == FIELD_TYPE.JSON: - # When SELECT from JSON column: charset = binary - # When SELECT CAST(... AS JSON): charset = connection - # encoding - # This behavior is different from TEXT / BLOB. - # We should decode result by connection encoding - # regardless charsetnr. - # See https://github.com/PyMySQL/PyMySQL/issues/488 - encoding = conn_encoding # SELECT CAST(... AS JSON) - elif field_type in TEXT_TYPES: - if field.charsetnr == 63: # binary - # TEXTs with charset=binary means BINARY types. - encoding = None - else: - encoding = conn_encoding - else: - # Integers, Dates and Times, and other basic data - # is encoded in ascii - encoding = 'ascii' - else: - encoding = None - converter = self.connection.decoders.get(field_type) + encoding = _decide_encoding(use_unicode, conn_encoding, field) + converter = self.connection.decoders.get(field.type_code) if converter is through: converter = None self.converters.append((encoding, converter)) diff --git a/aiomysql/cursors.py b/aiomysql/cursors.py index 97f0431f..c493d75f 100644 --- a/aiomysql/cursors.py +++ b/aiomysql/cursors.py @@ -3,13 +3,13 @@ import warnings import contextlib +from pymysql.constants import FIELD_TYPE from pymysql.err import ( Warning, Error, InterfaceError, DataError, DatabaseError, OperationalError, IntegrityError, InternalError, NotSupportedError, ProgrammingError) from .log import logger -from .connection import FIELD_TYPE # https://github.com/PyMySQL/PyMySQL/blob/master/pymysql/cursors.py#L11-L18 diff --git a/aiomysql/prepared_statement.py b/aiomysql/prepared_statement.py new file mode 100644 index 00000000..3971adb9 --- /dev/null +++ b/aiomysql/prepared_statement.py @@ -0,0 +1,294 @@ +import datetime +from decimal import Decimal +import struct + +from pymysql.connections import (FieldDescriptorPacket, MysqlPacket) +from pymysql.constants import (COMMAND, FIELD_TYPE) +from pymysql.converters import through +from pymysql.err import Error + +from .utils import _decide_encoding + + +class PreparedStatement(object): + def __init__(self, connection, stmt_id, params, columns): + self.connection = connection + self.stmt_id = stmt_id + self.params = params + self.columns = columns + self._rows = None + self._rownumber = 0 + self._rowcount = 0 + + async def execute(self, *args): + if len(args) != len(self.params): + raise Error("argument count doesn't match") + self.connection._next_seq_id = 0 + self._rownumber = 0 + data = struct.pack("!B", COMMAND.COM_STMT_EXECUTE) + data += struct.pack("> 8) + elif size <= 0xffffffff: + return b"\xfe" + struct.pack("> 3 + null_bitmap = self.read(null_bitmap_len) + result = [] + for i, c in enumerate(self._columns): + if null_bitmap[(i + 2) >> 3] & (1 << ((i + 2) % 8)): + result.append(None) + continue + # https://dev.mysql.com/doc/internals/en/binary-protocol-value.html + if c.type_code in _string_types: + n, is_none = self._read_length_encoded_integer() + if is_none: + result.append(None) + continue + data = self.read(n) + encoding, converter = self._converters[i] + if encoding is not None: + data = data.decode(encoding) + if converter is not None: + data = converter(data) + result.append(data) + continue + if c.type_code == FIELD_TYPE.LONGLONG: + result.append(self.read_struct("