diff --git a/.codacy.yml b/.codacy.yml
index cab97400..1e5a56b1 100644
--- a/.codacy.yml
+++ b/.codacy.yml
@@ -1,3 +1,4 @@
 exclude_paths:
   - 'tests/**'
+  - 'mysql_tests/**'
   - 'docs/**'
diff --git a/.coveragerc b/.coveragerc
index df037533..3aa9ea3b 100644
--- a/.coveragerc
+++ b/.coveragerc
@@ -1,6 +1,8 @@
 [run]
 source = ./src
+omit =
+    ./src/gino/aiocontextvars.py
 
 [report]
 exclude_lines =
     pragma: no cover
diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml
index e2077822..89ad2805 100644
--- a/.github/workflows/test.yml
+++ b/.github/workflows/test.yml
@@ -66,21 +66,127 @@ jobs:
           DB_HOST: localhost
           DB_USER: gino
         run: |
-          $HOME/.poetry/bin/poetry run pytest --cov=src --cov-fail-under=95 --cov-report xml
+          $HOME/.poetry/bin/poetry run pytest tests/
+  test-mysql:
+    runs-on: ubuntu-latest
+    strategy:
+      matrix:
+        python-version: [ '3.5', '3.6', '3.7', '3.8' ]
+        deps-version: [ 'lowest', 'highest' ]
+    services:
+      mysql:
+        image: mysql:5
+        env:
+          MYSQL_ALLOW_EMPTY_PASSWORD: 1
+        ports:
+          - 3306:3306
+        options: --health-cmd="mysqladmin ping" --health-interval=10s --health-timeout=5s --health-retries=3
+    steps:
+      - name: Checkout source code
+        uses: actions/checkout@v1
+      - name: Set up Python
+        uses: actions/setup-python@v1
+        with:
+          python-version: ${{ matrix.python-version }}
+      - name: virtualenv cache
+        uses: actions/cache@preview
+        with:
+          path: ~/.cache/pypoetry/virtualenvs
+          key: ${{ runner.os }}-${{ matrix.python-version }}-${{ matrix.deps-version }}-venv-${{ hashFiles(format('{0}{1}', github.workspace, '/poetry.lock')) }}
+          restore-keys: |
+            ${{ runner.os }}-${{ matrix.python-version }}-${{ matrix.deps-version }}-venv-
+      - name: Poetry cache
+        uses: actions/cache@preview
+        with:
+          path: ~/.poetry
+          key: ${{ runner.os }}-${{ matrix.python-version }}-dotpoetry
+          restore-keys: |
+            ${{ runner.os }}-${{ matrix.python-version }}-dotpoetry-
+      - name: Install Python dependencies
+        run: |
+          curl -sSL https://raw.githubusercontent.com/python-poetry/poetry/master/get-poetry.py | python
+          $HOME/.poetry/bin/poetry install --no-interaction
+      - name: Use lowest dependencies versions
+        if: matrix.deps-version == 'lowest'
+        run: |
+          $HOME/.poetry/bin/poetry run pip install asyncpg==0.18 SQLAlchemy==1.3
+      - name: List installed packages
+        run: |
+          $HOME/.poetry/bin/poetry run pip list
+      - name: Test with pytest
+        env:
+          MYSQL_DB_HOST: 127.0.0.1
+          MYSQL_DB_USER: root
+        run: |
+          $HOME/.poetry/bin/poetry run pytest mysql_tests/
+  summary:
+    runs-on: ubuntu-latest
+    services:
+      postgres:
+        image: fantix/postgres-ssl:12.1
+        env:
+          POSTGRES_USER: gino
+        ports:
+          - 5432:5432
+        # needed because the postgres container does not provide a healthcheck
+        options: --health-cmd pg_isready --health-interval 10s --health-timeout 5s --health-retries 5
+      mysql:
+        image: mysql:5
+        env:
+          MYSQL_ALLOW_EMPTY_PASSWORD: 1
+        ports:
+          - 3306:3306
+        options: --health-cmd="mysqladmin ping" --health-interval=10s --health-timeout=5s --health-retries=3
+    steps:
+      - name: Checkout source code
+        uses: actions/checkout@v1
+      - name: Set up Python
+        uses: actions/setup-python@v1
+        with:
+          python-version: 3.8
+      - name: virtualenv cache
+        uses: actions/cache@preview
+        with:
+          path: ~/.cache/pypoetry/virtualenvs
+          key: ${{ runner.os }}-3.8-highest-venv-${{ hashFiles(format('{0}{1}', github.workspace, '/poetry.lock')) }}
+          restore-keys: |
+            ${{ runner.os }}-3.8-highest-venv-
+      - name: Poetry cache
+        uses: actions/cache@preview
+        with:
+          path: ~/.poetry
+          key: ${{ runner.os }}-3.8-dotpoetry
+          restore-keys: |
+            ${{ runner.os }}-3.8-dotpoetry-
+      - name: Install Python dependencies
+        run: |
+          curl -sSL https://raw.githubusercontent.com/python-poetry/poetry/master/get-poetry.py | python
+          $HOME/.poetry/bin/poetry install --no-interaction
+      - name: List installed packages
+        run: |
+          $HOME/.poetry/bin/poetry run pip list
+      - name: Test with pytest
+        env:
+          DB_HOST: localhost
+          DB_USER: gino
+          MYSQL_DB_HOST: 127.0.0.1
+          MYSQL_DB_USER: root
+        run: |
+          $HOME/.poetry/bin/poetry run pytest --cov=src --cov-fail-under=95 --cov-report xml tests/ mysql_tests/
       - name: Check code format with black
-        if: matrix.python-version >= '3.6'
         run: |
           $HOME/.poetry/bin/poetry run black --check src
       - name: Submit coverage report
-        if: matrix.python-version == '3.8' && matrix.postgres-version == '12.1' && matrix.deps-version == 'highest' && github.ref == 'refs/heads/master'
+        if: github.ref == 'refs/heads/master'
        env:
           CODACY_PROJECT_TOKEN: ${{ secrets.CODACY_TOKEN }}
         run: |
           pip install codacy-coverage
           python-codacy-coverage -r coverage.xml
+
   release:
     runs-on: ubuntu-latest
-    needs: test
+    needs: summary
     strategy:
       matrix:
         python-version: [ '3.8' ]
diff --git a/mysql_tests/__init__.py b/mysql_tests/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/mysql_tests/conftest.py b/mysql_tests/conftest.py
new file mode 100644
index 00000000..1c72f269
--- /dev/null
+++ b/mysql_tests/conftest.py
@@ -0,0 +1,59 @@
+import ssl
+
+import aiomysql
+import pytest
+import sqlalchemy
+from async_generator import yield_, async_generator
+
+import gino
+from .models import db, DB_ARGS, MYSQL_URL, random_name
+
+ECHO = False
+
+
+@pytest.fixture(scope="module")
+def sa_engine():
+    rv = sqlalchemy.create_engine(MYSQL_URL, echo=ECHO)
+    db.create_all(rv)
+    yield rv
+    db.drop_all(rv)
+    rv.dispose()
+
+
+@pytest.fixture
+@async_generator
+async def engine(sa_engine):
+    e = await gino.create_engine(MYSQL_URL, echo=ECHO, minsize=10)
+    await yield_(e)
+    await e.close()
+    sa_engine.execute("DELETE FROM gino_user_settings")
+    sa_engine.execute("DELETE FROM gino_users")
+
+
+# noinspection PyUnusedLocal,PyShadowingNames
+@pytest.fixture
+@async_generator
+async def bind(sa_engine):
+    async with db.with_bind(MYSQL_URL, echo=ECHO, minsize=10) as e:
+        await yield_(e)
+    sa_engine.execute("DELETE FROM gino_user_settings")
+    sa_engine.execute("DELETE FROM gino_users")
+
+
+# noinspection PyUnusedLocal,PyShadowingNames
+@pytest.fixture
+@async_generator
+async def aiomysql_pool(sa_engine):
+    async with aiomysql.create_pool(**DB_ARGS) as rv:
+        await yield_(rv)
+        async with rv.acquire() as conn:
+            await conn.query("DELETE FROM gino_user_settings")
+            await conn.query("DELETE FROM gino_users")
+
+
+@pytest.fixture
+def ssl_ctx():
+    ctx = ssl.create_default_context()
+    ctx.check_hostname = False
+    ctx.verify_mode = ssl.CERT_NONE
+    return ctx
diff --git a/mysql_tests/models.py b/mysql_tests/models.py
new file mode 100644
index 00000000..52dc2484
--- /dev/null
+++ b/mysql_tests/models.py
@@ -0,0 +1,157 @@
+import os
+import enum
+import random
+import string
+from datetime import datetime
+
+import aiomysql
+import asyncpg
+import pytest
+
+from gino import Gino
+from gino.dialects.aiomysql import JSON
+
+DB_ARGS = dict(
+    host=os.getenv("MYSQL_DB_HOST", "localhost"),
+    port=os.getenv("MYSQL_DB_PORT", 3306),
+    user=os.getenv("MYSQL_DB_USER", "root"),
+    password=os.getenv("MYSQL_DB_PASS", ""),
+    db=os.getenv("MYSQL_DB_NAME", "mysql"),
+)
+MYSQL_URL = "mysql://{user}:{password}@{host}:{port}/{db}".format(**DB_ARGS)
+db = Gino()
+
+
+@pytest.fixture
+def 
random_name(length=8) -> str: + return _random_name(length) + + +def _random_name(length=8): + return "".join(random.choice(string.ascii_letters) for _ in range(length)) + + +class UserType(enum.Enum): + USER = "USER" + + +class User(db.Model): + __tablename__ = "gino_users" + + id = db.Column(db.BigInteger(), primary_key=True) + nickname = db.Column("name", db.Unicode(255), default=_random_name) + profile = db.Column("props", JSON(), nullable=False, default="{}") + type = db.Column(db.Enum(UserType), nullable=False, default=UserType.USER) + realname = db.StringProperty() + age = db.IntegerProperty(default=18) + balance = db.IntegerProperty(default=0) + birthday = db.DateTimeProperty(default=lambda i: datetime.utcfromtimestamp(0)) + team_id = db.Column(db.ForeignKey("gino_teams.id")) + + @balance.after_get + def balance(self, val): + if val is None: + return 0.0 + return float(val) + + def __repr__(self): + return "{}<{}>".format(self.nickname, self.id) + + +class Friendship(db.Model): + __tablename__ = "gino_friendship" + + my_id = db.Column(db.BigInteger(), primary_key=True) + friend_id = db.Column(db.BigInteger(), primary_key=True) + + def __repr__(self): + return "Friends<{}, {}>".format(self.my_id, self.friend_id) + + +class Relation(db.Model): + __tablename__ = "gino_relation" + + name = db.Column(db.VARCHAR(255), primary_key=True) + + +class Team(db.Model): + __tablename__ = "gino_teams" + + id = db.Column(db.BigInteger(), primary_key=True) + name = db.Column(db.Unicode(255), default=_random_name) + parent_id = db.Column(db.ForeignKey("gino_teams.id", ondelete='CASCADE')) + company_id = db.Column(db.ForeignKey("gino_companies.id")) + + def __init__(self, **kw): + super().__init__(**kw) + self._members = set() + + @property + def members(self): + return self._members + + @members.setter + def add_member(self, user): + self._members.add(user) + + +class TeamWithDefaultCompany(Team): + company = Team(name="DEFAULT") + + +class TeamWithoutMembersSetter(Team): + def add_member(self, user): + self._members.add(user) + + +class Company(db.Model): + __tablename__ = "gino_companies" + + id = db.Column(db.BigInteger(), primary_key=True) + name = db.Column(db.Unicode(255), default=_random_name) + logo = db.Column(db.LargeBinary()) + + def __init__(self, **kw): + super().__init__(**kw) + self._teams = set() + + @property + def teams(self): + return self._teams + + @teams.setter + def add_team(self, team): + self._teams.add(team) + + +class CompanyWithoutTeamsSetter(Company): + def add_team(self, team): + self._teams.add(team) + + +class UserSetting(db.Model): + __tablename__ = "gino_user_settings" + + # No constraints defined on columns + id = db.Column(db.BigInteger()) + user_id = db.Column(db.BigInteger()) + setting = db.Column(db.VARCHAR(255)) + value = db.Column(db.Text()) + col1 = db.Column(db.Integer, default=1) + col2 = db.Column(db.Integer, default=2) + + # Define indexes and constraints inline + id_pkey = db.PrimaryKeyConstraint("id") + user_id_fk = db.ForeignKeyConstraint(["user_id"], ["gino_users.id"]) + user_id_setting_unique = db.UniqueConstraint("user_id", "setting") + col1_check = db.CheckConstraint("col1 >= 1 AND col1 <= 5") + col2_idx = db.Index("col2_idx", "col2") + + +def qsize(engine): + if isinstance(engine.raw_pool, aiomysql.pool.Pool): + return engine.raw_pool.freesize + if isinstance(engine.raw_pool, asyncpg.pool.Pool): + # noinspection PyProtectedMember + return engine.raw_pool._queue.qsize() + raise Exception('Unknown pool') diff --git a/mysql_tests/test_bakery.py 
b/mysql_tests/test_bakery.py new file mode 100644 index 00000000..5be672b3 --- /dev/null +++ b/mysql_tests/test_bakery.py @@ -0,0 +1,162 @@ +import pytest +import sqlalchemy + +from gino import UninitializedError, create_engine, InitializedError +from gino.bakery import Bakery, BakedQuery +from .models import db, User, MYSQL_URL + +pytestmark = pytest.mark.asyncio + + +@pytest.mark.parametrize( + "query", + [ + User.query.where(User.id == db.bindparam("uid")), + sqlalchemy.text("SELECT * FROM gino_users WHERE id = :uid"), + "SELECT * FROM gino_users WHERE id = :uid", + lambda: User.query.where(User.id == db.bindparam("uid")), + lambda: sqlalchemy.text("SELECT * FROM gino_users WHERE id = :uid"), + lambda: "SELECT * FROM gino_users WHERE id = :uid", + ], +) +@pytest.mark.parametrize("options", [dict(return_model=False), dict(loader=User)]) +@pytest.mark.parametrize("api", [True, False]) +@pytest.mark.parametrize("timeout", [None, 1]) +async def test(query, options, sa_engine, api, timeout): + uid = sa_engine.execute(User.insert()).lastrowid + if timeout: + options["timeout"] = timeout + + if api: + b = db._bakery + qs = [db.bake(query, **options)] + if callable(query): + qs.append(db.bake(**options)(query)) + else: + b = Bakery() + qs = [b.bake(query, **options)] + if callable(query): + qs.append(b.bake(**options)(query)) + + for q in qs: + assert isinstance(q, BakedQuery) + assert q in list(b) + assert q.sql is None + assert q.compiled_sql is None + + with pytest.raises(UninitializedError): + q.bind.first() + with pytest.raises(UninitializedError): + await q.first() + + for k, v in options.items(): + assert q.query.get_execution_options()[k] == v + + if api: + e = await db.set_bind(MYSQL_URL, minsize=1) + else: + e = await create_engine(MYSQL_URL, bakery=b, minsize=1) + + with pytest.raises(InitializedError): + b.bake("SELECT now()") + + with pytest.raises(InitializedError): + await create_engine(MYSQL_URL, bakery=b, minsize=0) + + try: + for q in qs: + assert q.sql is not None + assert q.compiled_sql is not None + + if api: + assert q.bind is e + else: + with pytest.raises(UninitializedError): + q.bind.first() + with pytest.raises(UninitializedError): + await q.first() + + if api: + rv = await q.first(uid=uid) + else: + rv = await e.first(q, uid=uid) + + if options.get("return_model", True): + assert isinstance(rv, User) + assert rv.id == uid + else: + assert rv[0] == rv[User.id] == rv["id"] == uid + + eq = q.execution_options(return_model=True, loader=User) + assert eq is not q + assert isinstance(eq, BakedQuery) + assert type(eq) is not BakedQuery + assert eq in list(b) + assert eq.sql == q.sql + assert eq.compiled_sql is not q.compiled_sql + + if api: + assert q.bind is e + else: + with pytest.raises(UninitializedError): + eq.bind.first() + with pytest.raises(UninitializedError): + await eq.first() + + assert eq.query.get_execution_options()["return_model"] + assert eq.query.get_execution_options()["loader"] is User + + if api: + rv = await eq.first(uid=uid) + non = await eq.first(uid=uid + 1) + rvl = await eq.all(uid=uid) + else: + rv = await e.first(eq, uid=uid) + non = await e.first(eq, uid=uid + 1) + rvl = await e.all(eq, uid=uid) + + assert isinstance(rv, User) + assert rv.id == uid + + assert non is None + + assert len(rvl) == 1 + assert rvl[0].id == uid + + # original query is not affected + if api: + rv = await q.first(uid=uid) + else: + rv = await e.first(q, uid=uid) + + if options.get("return_model", True): + assert isinstance(rv, User) + assert rv.id == uid + else: + assert 
rv[0] == rv[User.id] == rv["id"] == uid + + finally: + if api: + await db.pop_bind().close() + else: + await e.close() + + +async def test_class_level_bake(): + class BakeOnClass(db.Model): + __tablename__ = "bake_on_class_test" + + name = db.Column(db.String(255), primary_key=True) + + @db.bake + def getter(cls): + return cls.query.where(cls.name == db.bindparam("name")) + + async with db.with_bind(MYSQL_URL, prebake=False): + await db.gino.create_all() + try: + await BakeOnClass.create(name="exist") + assert (await BakeOnClass.getter.one(name="exist")).name == "exist" + assert (await BakeOnClass.getter.one_or_none(name="nonexist")) is None + finally: + await db.gino.drop_all() diff --git a/mysql_tests/test_bind.py b/mysql_tests/test_bind.py new file mode 100644 index 00000000..61ae78bc --- /dev/null +++ b/mysql_tests/test_bind.py @@ -0,0 +1,63 @@ +import random + +import pytest +from gino.exceptions import UninitializedError +from sqlalchemy.engine.url import make_url + +from .models import db, MYSQL_URL, User + +pytestmark = pytest.mark.asyncio + + +# noinspection PyUnusedLocal +async def test_create(bind): + nickname = "test_create_{}".format(random.random()) + u = await User.create(nickname=nickname) + assert u.id is not None + assert u.nickname == nickname + return u + + +async def test_get(bind): + u1 = await test_create(bind) + u2 = await User.get(u1.id) + assert u1.id == u2.id + assert u1.nickname == u2.nickname + assert u1 is not u2 + + +# noinspection PyUnusedLocal +async def test_unbind(aiomysql_pool): + await db.set_bind(MYSQL_URL) + await test_create(None) + await db.pop_bind().close() + db.bind = None + with pytest.raises(UninitializedError): + await test_create(None) + # test proper exception when engine is not initialized + with pytest.raises(UninitializedError): + db.bind.first = lambda x: 1 + + +async def test_db_api(bind, random_name): + result = await db.first(User.insert().values(name=random_name)) + assert result is None + r = await db.scalar(User.select('nickname').where(User.nickname == random_name)) + assert r == random_name + assert ( + await db.first(User.query.where(User.nickname == random_name)) + ).nickname == random_name + assert len(await db.all(User.query.where(User.nickname == random_name))) == 1 + assert (await db.status(User.delete.where(User.nickname == random_name)))[ + 0 + ] == 1 + stmt, params = db.compile(User.query.where(User.id == 3)) + assert params[0] == 3 + + +async def test_bind_url(): + url = make_url(MYSQL_URL) + assert url.drivername == "mysql" + await db.set_bind(MYSQL_URL) + assert url.drivername == "mysql" + await db.pop_bind().close() diff --git a/mysql_tests/test_core.py b/mysql_tests/test_core.py new file mode 100644 index 00000000..39809086 --- /dev/null +++ b/mysql_tests/test_core.py @@ -0,0 +1,103 @@ +import pytest +from sqlalchemy import Table, Column, Integer, String, MetaData, ForeignKey +from sqlalchemy.engine.result import RowProxy + +from .models import MYSQL_URL + +pytestmark = pytest.mark.asyncio + + +async def test_engine_only(): + import gino + from gino.schema import GinoSchemaVisitor + + metadata = MetaData() + + users = Table( + "users", + metadata, + Column("id", Integer, primary_key=True), + Column("name", String(255)), + Column("fullname", String(255)), + ) + + Table( + "addresses", + metadata, + Column("id", Integer, primary_key=True), + Column("user_id", None, ForeignKey("users.id")), + Column("email_address", String(255), nullable=False), + ) + + engine = await gino.create_engine(MYSQL_URL) + await 
GinoSchemaVisitor(metadata).create_all(engine) + try: + ins = users.insert().values(name="jack", fullname="Jack Jones") + await engine.status(ins) + res = await engine.all(users.select()) + assert isinstance(res[0], RowProxy) + finally: + await GinoSchemaVisitor(metadata).drop_all(engine) + await engine.close() + + +async def test_core(): + from gino import Gino + + db = Gino() + + users = db.Table( + "users", + db, + db.Column("id", db.Integer, primary_key=True), + db.Column("name", db.String(255)), + db.Column("fullname", db.String(255)), + ) + + db.Table( + "addresses", + db, + db.Column("id", db.Integer, primary_key=True), + db.Column("user_id", None, db.ForeignKey("users.id")), + db.Column("email_address", db.String(255), nullable=False), + ) + + async with db.with_bind(MYSQL_URL): + await db.gino.create_all() + try: + await users.insert().values( + name="jack", fullname="Jack Jones", + ).gino.status() + res = await users.select().gino.all() + assert isinstance(res[0], RowProxy) + finally: + await db.gino.drop_all() + + +async def test_orm(): + from gino import Gino + + db = Gino() + + class User(db.Model): + __tablename__ = "users" + + id = db.Column(db.Integer, primary_key=True) + name = db.Column(db.String(255)) + fullname = db.Column(db.String(255)) + + class Address(db.Model): + __tablename__ = "addresses" + + id = db.Column(db.Integer, primary_key=True) + user_id = db.Column(None, db.ForeignKey("users.id")) + email_address = db.Column(db.String(255), nullable=False) + + async with db.with_bind(MYSQL_URL): + await db.gino.create_all() + try: + await User.create(name="jack", fullname="Jack Jones") + res = await User.query.gino.all() + assert isinstance(res[0], User) + finally: + await db.gino.drop_all() diff --git a/mysql_tests/test_crud.py b/mysql_tests/test_crud.py new file mode 100644 index 00000000..0f4f4775 --- /dev/null +++ b/mysql_tests/test_crud.py @@ -0,0 +1,333 @@ +import random + +import pytest + +from .models import db, User, UserType, Friendship, Relation, MYSQL_URL + +pytestmark = pytest.mark.asyncio + + +async def test_create(engine): + nickname = "test_create_{}".format(random.random()) + u = await User.create( + bind=engine, timeout=10, nickname=nickname, age=42, type=UserType.USER + ) + assert u.id is not None + assert u.nickname == nickname + assert u.type == UserType.USER + assert u.age == 42 + + u2 = await User.get(u.id, bind=engine, timeout=10) + assert u2.id == u.id + assert u2.nickname == nickname + assert u2.type == UserType.USER + assert u2.age == 42 + assert u2 is not u + + return u + + +async def test_create_from_instance(engine): + nickname = "test_create_from_instance_{}".format(random.random()) + u = User(nickname="will-be-replaced", type=UserType.USER, age=42) + u.nickname = nickname + u.age = 21 + await u.create(bind=engine, timeout=10) + assert u.id is not None + assert u.nickname == nickname + assert u.type == UserType.USER + assert u.age == 21 + + u2 = await User.get(u.id, bind=engine, timeout=10) + assert u2.id == u.id + assert u2.nickname == nickname + assert u2.type == UserType.USER + assert u2.age == 21 + assert u2 is not u + + return u + + +async def test_get(engine): + u1 = await test_create(engine) + u2 = await User.get(u1.id, bind=engine, timeout=10) + assert u1.id == u2.id + assert u1.nickname == u2.nickname + assert u1 is not u2 + + u3 = await engine.first(u1.query) + assert u1.id == u3.id + assert u1.nickname == u3.nickname + assert u1 is not u3 + + u4 = await test_create_from_instance(engine) + u5 = await engine.first(u4.query) + 
assert u4.id == u5.id + assert u4.nickname == u5.nickname + assert u4 is not u5 + + +async def test_textual_sql(engine): + u1 = await test_create(engine) + u2 = await engine.first( + db.text("SELECT * FROM gino_users WHERE id = :uid") + .bindparams(uid=u1.id) + .columns(*User) + .execution_options(model=User) + ) + assert isinstance(u2, User) + assert u1.id == u2.id + assert u1.nickname == u2.nickname + assert u1.type is u2.type + assert u1 is not u2 + + u2 = await engine.first( + db.text("SELECT * FROM gino_users WHERE id = :uid AND type = :utype") + .bindparams(db.bindparam("utype", type_=db.Enum(UserType))) + .bindparams(uid=u1.id, utype=UserType.USER,) + .columns(*User) + .execution_options(model=User) + ) + assert isinstance(u2, User) + assert u1.id == u2.id + assert u1.nickname == u2.nickname + assert u1.type is u2.type + assert u1 is not u2 + + +async def test_select(engine): + u = await test_create(engine) + name = await engine.scalar(User.select("nickname").where(User.id == u.id)) + assert u.nickname == name + + name = await engine.scalar(u.select("nickname")) + assert u.nickname == name + + +async def test_get_multiple_primary_key(engine): + u1 = await test_create(engine) + u2 = await test_create(engine) + await Friendship.create(bind=engine, my_id=u1.id, friend_id=u2.id) + with pytest.raises(ValueError, match="Incorrect number of values as primary key"): + await Friendship.get((u1.id,), bind=engine) + with pytest.raises(ValueError, match="Incorrect number of values as primary key"): + await Friendship.get(u1.id, bind=engine) + f = await Friendship.get((u1.id, u2.id), bind=engine) + assert f + assert f.my_id == u1.id + assert f.friend_id == u2.id + + +async def test_multiple_primary_key_order(): + import gino + + db1 = gino.Gino() + await db1.set_bind(MYSQL_URL) + + class NameCard(db1.Model): + __tablename__ = "name_cards" + + first_name = db1.Column(db1.Unicode(255), primary_key=True) + last_name = db1.Column(db1.Unicode(255), primary_key=True) + + await db1.gino.create_all() + + try: + await NameCard.create(first_name="first", last_name="last") + nc = await NameCard.get(("first", "last")) + assert nc.first_name == "first" + assert nc.last_name == "last" + with pytest.raises(ValueError, match="expected 2, got 3"): + await NameCard.get(dict(a=1, first_name="first", last_name="last")) + with pytest.raises(KeyError, match="first_name"): + await NameCard.get(dict(first="first", last_name="last")) + nc = await NameCard.get(dict(first_name="first", last_name="last")) + assert nc.first_name == "first" + assert nc.last_name == "last" + nc = await NameCard.get({0: "first", 1: "last"}) + assert nc.first_name == "first" + assert nc.last_name == "last" + finally: + await db1.gino.drop_all() + await db1.pop_bind().close() + + db2 = gino.Gino(MYSQL_URL) + await db2.set_bind(MYSQL_URL) + + class NameCard(db2.Model): + __tablename__ = "name_cards" + + last_name = db2.Column(db2.Unicode(255), primary_key=True) + first_name = db2.Column(db2.Unicode(255), primary_key=True) + + await db2.gino.create_all() + + try: + await NameCard.create(first_name="first", last_name="last") + nc = await NameCard.get(("last", "first")) + assert nc.first_name == "first" + assert nc.last_name == "last" + nc = await NameCard.get(dict(first_name="first", last_name="last")) + assert nc.first_name == "first" + assert nc.last_name == "last" + nc = await NameCard.get({1: "first", "last_name": "last"}) + assert nc.first_name == "first" + assert nc.last_name == "last" + finally: + await db2.gino.drop_all() + await 
db2.pop_bind().close() + + +async def test_connection_as_bind(engine): + async with engine.acquire() as conn: + await test_get(conn) + + +async def test_update(engine, random_name): + u1 = await test_create(engine) + await u1.update(nickname=random_name).apply(bind=engine, timeout=10) + u2 = await User.get(u1.id, bind=engine) + assert u2.nickname == random_name + + +async def test_update_missing(engine, random_name): + from gino.exceptions import NoSuchRowError + + u1 = await test_create(engine) + rq = u1.update(nickname=random_name) + await u1.delete(bind=engine) + with pytest.raises(NoSuchRowError): + await rq.apply(bind=engine, timeout=10) + + +async def test_update_multiple_primary_key(engine): + u1 = await test_create(engine) + u2 = await test_create(engine) + u3 = await test_create(engine) + await Friendship.create(bind=engine, my_id=u1.id, friend_id=u2.id) + f = await Friendship.get((u1.id, u2.id), bind=engine) + await f.update(my_id=u2.id, friend_id=u3.id).apply(bind=engine) + f2 = await Friendship.get((u2.id, u3.id), bind=engine) + assert f2 + + +async def test_delete(engine): + u1 = await test_create(engine) + await u1.delete(bind=engine, timeout=10) + u2 = await User.get(u1.id, bind=engine) + assert not u2 + + +async def test_delete_bind(bind): + u1 = await test_create(bind) + await u1.delete(timeout=10) + u2 = await User.get(u1.id) + assert not u2 + + +async def test_delete_multiple_primary_key(engine): + u1 = await test_create(engine) + u2 = await test_create(engine) + f = await Friendship.create(bind=engine, my_id=u1.id, friend_id=u2.id) + await f.delete(bind=engine) + f2 = await Friendship.get((u1.id, u2.id), bind=engine) + assert not f2 + + +async def test_string_primary_key(engine): + relations = ["Colleagues", "Friends", "Lovers"] + for r in relations: + await Relation.create(bind=engine, timeout=10, name=r) + r1 = await Relation.get(relations[0], bind=engine, timeout=10) + assert r1.name == relations[0] + + +async def test_lookup_287(bind): + from gino.exceptions import NoSuchRowError + + class Game(db.Model): + __tablename__ = "games" + game_id = db.Column(db.String(32), unique=True) + channel_id = db.Column(db.String(1), default="A") + + await Game.gino.create() + try: + game_1 = await Game.create(game_id="1", channel_id="X") + game_2 = await Game.create(game_id="2", channel_id="Y") + + # ordinary update should be fine + uq = game_1.update(game_id="3") + + with pytest.raises(TypeError, match="Model Game has no table, primary key"): + # but applying the updates to DB should fail + await uq.apply() + + with pytest.raises( + LookupError, match="Instance-level CRUD operations not allowed" + ): + await game_2.delete() + with pytest.raises( + LookupError, match="Instance-level CRUD operations not allowed" + ): + await game_2.query.gino.all() + with pytest.raises( + LookupError, match="Instance-level CRUD operations not allowed" + ): + await game_2.select("game_id") + + # previous ordinary update still in effect + assert game_1.game_id == "3" + + assert await Game.select("game_id").gino.all() == [("1",), ("2",)] + + Game.lookup = lambda self: Game.game_id == self.game_id + with pytest.raises(NoSuchRowError): + await game_1.update(channel_id="Z").apply() + await game_2.update(channel_id="Z").apply() + assert await Game.select("channel_id").gino.all() == [("X",), ("Z",)] + finally: + await Game.gino.drop() + + +async def test_lookup_custom_name(bind): + class ModelWithCustomColumnNames(db.Model): + __tablename__ = "gino_test_custom_column_names" + + id = db.Column("other", 
db.Integer(), primary_key=True) + field = db.Column(db.Text()) + + await ModelWithCustomColumnNames.gino.create() + + try: + # create + m1 = await ModelWithCustomColumnNames.create(id=1, field="A") + m2 = await ModelWithCustomColumnNames.create(id=2, field="B") + + # update + uq = m1.update(field="C") + await uq.apply() + + # lookup + assert set( + tuple(x) for x in await ModelWithCustomColumnNames.select("id").gino.all() + ) == {(1,), (2,)} + assert (await ModelWithCustomColumnNames.get(2)).field == "B" + assert (await ModelWithCustomColumnNames.get(1)).field == "C" + assert await ModelWithCustomColumnNames.get(3) is None + + # delete + assert ( + await ModelWithCustomColumnNames.delete.where( + ModelWithCustomColumnNames.id == 3 + ).gino.status() + )[0] == 0 + assert ( + await ModelWithCustomColumnNames.delete.where( + ModelWithCustomColumnNames.id == 2 + ).gino.status() + )[0] == 1 + assert set( + tuple(x) for x in await ModelWithCustomColumnNames.select("id").gino.all() + ) == {(1,)} + finally: + await ModelWithCustomColumnNames.gino.drop() diff --git a/mysql_tests/test_declarative.py b/mysql_tests/test_declarative.py new file mode 100644 index 00000000..0aad6334 --- /dev/null +++ b/mysql_tests/test_declarative.py @@ -0,0 +1,305 @@ +import pytest + +import gino +from gino.declarative import InvertDict +from aiomysql import IntegrityError + +from .models import User, UserSetting + +pytestmark = pytest.mark.asyncio +db = gino.Gino() + + +# noinspection PyUnusedLocal +async def test_column_not_deletable(bind): + u = await User.create(nickname="test") + with pytest.raises(AttributeError): + del u.nickname + + +async def test_table_args(): + class Model(db.Model): + __tablename__ = "model1" + + assert Model.__table__.implicit_returning + + class Model(db.Model): + __tablename__ = "model2" + + __table_args__ = dict(implicit_returning=False) + + assert not Model.__table__.implicit_returning + + class Model(db.Model): + __tablename__ = "model3" + + __table_args__ = db.Column("new_col"), dict(implicit_returning=False) + + assert not Model.__table__.implicit_returning + assert not hasattr(Model, "new_col") + assert not hasattr(Model.__table__.c, "nonexist") + assert hasattr(Model.__table__.c, "new_col") + + class Model(db.Model): + __tablename__ = "model4" + __table_args__ = db.Column("col1"), db.Column("col2") + + col3 = db.Column() + + assert not hasattr(Model, "col1") + assert not hasattr(Model, "col2") + assert hasattr(Model, "col3") + assert hasattr(Model.__table__.c, "col1") + assert hasattr(Model.__table__.c, "col2") + assert hasattr(Model.__table__.c, "col3") + + class Model(db.Model): + @db.declared_attr + def __tablename__(cls): + return "model5" + + assert Model.__table__.name == "model5" + + +async def test_inline_constraints_and_indexes(bind, engine): + u = await User.create(nickname="test") + us1 = await UserSetting.create(user_id=u.id, setting="skin", value="blue") + + # PrimaryKeyConstraint + with pytest.raises(IntegrityError): + await UserSetting.create(id=us1.id, user_id=u.id, setting="key1", value="val1") + + # ForeignKeyConstraint + with pytest.raises(IntegrityError): + await UserSetting.create(user_id=42, setting="key2", value="val2") + + # UniqueConstraint + with pytest.raises(IntegrityError): + await UserSetting.create( + user_id=u.id, setting="skin", value="duplicate-setting" + ) + + # MySQL doesn't support CheckConstraint + # with pytest.raises(CheckViolationError): + # await UserSetting.create(user_id=u.id, setting="key3", value="val3", col1=42) + + # Index + status, 
result = await engine.status( + "SHOW INDEXES FROM `gino_user_settings` WHERE Key_name = 'col2_idx'" + ) + assert status == 1 + + +async def test_join_t112(engine): + class Car(db.Model): + __tablename__ = "cars" + + id = db.Column(db.BigInteger(), primary_key=True) + + class Wheel(db.Model): + __tablename__ = "wheels" + + id = db.Column(db.BigInteger(), primary_key=True) + car_id = db.Column(db.ForeignKey("cars.id")) + + sql = ( + "SELECT wheels.id, wheels.car_id, cars.id \nFROM wheels " + "INNER JOIN cars ON cars.id = wheels.car_id" + ) + + assert engine.compile(Wheel.join(Car).select())[0] == sql + + +async def test_mixin(): + class Tracked: + created = db.Column(db.DateTime(timezone=True)) + + @db.declared_attr + def unique_id(cls): + return db.Column(db.Integer()) + + @db.declared_attr + def unique_constraint(cls): + return db.UniqueConstraint("unique_id") + + @db.declared_attr + def poly(cls): + if cls.__name__ == "Thing": + return db.Column(db.Unicode()) + + @db.declared_attr + def __table_args__(cls): + if cls.__name__ == "Thing": + return (db.UniqueConstraint("poly"),) + + class Audit(Tracked): + pass + + class Thing(Audit, db.Model): + __tablename__ = "thing" + + id = db.Column(db.Integer, primary_key=True) + + class Another(Audit, db.Model): + __tablename__ = "another" + + id = db.Column(db.Integer, primary_key=True) + + assert isinstance(Thing.__table__.c.created, db.Column) + assert isinstance(Another.__table__.c.created, db.Column) + assert Thing.created is not Another.created + assert Thing.created is Thing.__table__.c.created + assert Another.created is Another.__table__.c.created + + assert Thing.unique_id is not Another.unique_id + assert Thing.unique_id is Thing.__table__.c.unique_id + c1, c2 = [ + list( + filter( + lambda c: list(c.columns)[0].name == "unique_id", + m.__table__.constraints, + ) + )[0] + for m in [Thing, Another] + ] + assert isinstance(c1, db.UniqueConstraint) + assert isinstance(c2, db.UniqueConstraint) + assert c1 is not c2 + + assert isinstance(Thing.poly, db.Column) + assert Another.poly is None + for c in Thing.__table__.constraints: + if list(c.columns)[0].name == "poly": + assert isinstance(c, db.UniqueConstraint) + break + else: + assert False, "Should not reach here" + + +# noinspection PyUnusedLocal +async def test_inherit_constraint(): + with pytest.raises(ValueError, match="already attached to another table"): + + class IllegalUserSetting(UserSetting): + __table__ = None + __tablename__ = "bad_gino_user_settings" + + +async def test_abstract_model_error(): + class ConcreteModel(db.Model): + __tablename__ = "some_table" + + c = db.Column(db.Unicode()) + + class AbstractModel(db.Model): + pass + + with pytest.raises(TypeError, match="AbstractModel is abstract"): + ConcreteModel.join(AbstractModel) + + with pytest.raises(TypeError, match="AbstractModel is abstract"): + AbstractModel.join(ConcreteModel) + + with pytest.raises(TypeError, match="AbstractModel is abstract"): + db.select(AbstractModel) + + with pytest.raises(TypeError, match="AbstractModel is abstract"): + db.select([AbstractModel]) + + with pytest.raises(TypeError, match="AbstractModel is abstract"): + # noinspection PyStatementEffect + AbstractModel.query + + with pytest.raises(TypeError, match="AbstractModel is abstract"): + # noinspection PyStatementEffect + AbstractModel.update + + am = AbstractModel() + + with pytest.raises(TypeError, match="AbstractModel is abstract"): + await am.create() + + with pytest.raises(TypeError, match="AbstractModel is abstract"): + await 
am.delete() + + req = am.update() + + with pytest.raises(TypeError, match="AbstractModel has no table"): + await req.apply() + + with pytest.raises(TypeError, match="AbstractModel is abstract"): + # noinspection PyStatementEffect + AbstractModel.delete + + with pytest.raises(TypeError, match="AbstractModel is abstract"): + AbstractModel.alias() + + with pytest.raises(TypeError, match="AbstractModel is abstract"): + AbstractModel.alias() + + with pytest.raises(TypeError, match="AbstractModel is abstract"): + await AbstractModel.get(1) + + +async def test_invert_dict(): + with pytest.raises(gino.GinoException, match=r"Column name c1 already maps to \w+"): + InvertDict({"col1": "c1", "col2": "c1"}) + + with pytest.raises(gino.GinoException, match=r"Column name c1 already maps to \w+"): + d = InvertDict() + d["col1"] = "c1" + d["col2"] = "c1" + + d = InvertDict() + d["col1"] = "c1" + # it works for same key/value pair + d["col1"] = "c1" + d["col2"] = "c2" + assert d.invert_get("c1") == "col1" + assert d.invert_get("c2") == "col2" + + +async def test_instant_column_name(): + class Model(db.Model): + user = db.Column() + assert user.name == "user" + + select_col = db.Column(name=db.quoted_name("select", False)) + assert select_col.name == "select" + assert not select_col.name.quote + + +async def test_overwrite_declared_table_name(): + class MyTableNameMixin: + @db.declared_attr + def __tablename__(cls): + return cls.__name__.lower() + + class MyTableWithoutName(MyTableNameMixin, db.Model): + id = db.Column(db.Integer, primary_key=True) + + class MyTableWithName(MyTableNameMixin, db.Model): + __tablename__ = "manually_overwritten_name" + id = db.Column(db.Integer, primary_key=True) + + assert MyTableWithoutName.__table__.name == "mytablewithoutname" + assert MyTableWithName.__table__.name == "manually_overwritten_name" + + +async def test_multiple_inheritance_overwrite_declared_table_name(): + class MyTableNameMixin: + @db.declared_attr + def __tablename__(cls): + return cls.__name__.lower() + + class AnotherTableNameMixin: + __tablename__ = "static_table_name" + + class MyTableWithoutName(AnotherTableNameMixin, MyTableNameMixin, db.Model): + id = db.Column(db.Integer, primary_key=True) + + class MyOtherTableWithoutName(MyTableNameMixin, AnotherTableNameMixin, db.Model): + id = db.Column(db.Integer, primary_key=True) + + assert MyTableWithoutName.__table__.name == "static_table_name" + assert MyOtherTableWithoutName.__table__.name == "myothertablewithoutname" diff --git a/mysql_tests/test_dialect.py b/mysql_tests/test_dialect.py new file mode 100644 index 00000000..b5350ed3 --- /dev/null +++ b/mysql_tests/test_dialect.py @@ -0,0 +1,9 @@ +import pytest +from .models import Company + +pytestmark = pytest.mark.asyncio + + +async def test_225_large_binary(bind): + c = await Company.create(logo=b"SVG LOGO") + assert c.logo == b"SVG LOGO" diff --git a/mysql_tests/test_engine.py b/mysql_tests/test_engine.py new file mode 100644 index 00000000..06df9a2b --- /dev/null +++ b/mysql_tests/test_engine.py @@ -0,0 +1,403 @@ +import asyncio +import logging +from datetime import datetime + +import pymysql +import aiomysql +from gino import create_engine, UninitializedError +import pytest +from sqlalchemy.exc import ObjectNotExecutableError +import sqlalchemy as sa + +from .models import db, User, MYSQL_URL, qsize + +pytestmark = pytest.mark.asyncio + + +async def test_basic(engine): + init_size = qsize(engine) + async with engine.acquire() as conn: + assert isinstance(conn.raw_connection, aiomysql.Connection) + 
assert init_size == qsize(engine) + assert isinstance(await engine.scalar("select now()"), datetime) + assert isinstance(await engine.scalar(sa.text("select now()")), datetime) + assert isinstance((await engine.first("select now()"))[0], datetime) + assert isinstance((await engine.all("select now()"))[0][0], datetime) + assert isinstance((await engine.one("select now()"))[0], datetime) + assert isinstance((await engine.one_or_none("select now()"))[0], datetime) + status, result = await engine.status("select now()") + assert status == 1 + assert isinstance(result[0][0], datetime) + with pytest.raises(ObjectNotExecutableError): + await engine.all(object()) + + +async def test_issue_79(): + e = await create_engine(MYSQL_URL + "_non_exist", minsize=0) + with pytest.raises(pymysql.err.OperationalError): + async with e.acquire(): + pass # pragma: no cover + # noinspection PyProtectedMember + assert len(e._ctx.get([])) == 0 + + +async def test_reuse(engine): + init_size = qsize(engine) + async with engine.acquire(reuse=True) as conn1: + assert qsize(engine) == init_size - 1 + async with engine.acquire(reuse=True) as conn2: + assert qsize(engine) == init_size - 1 + assert conn1.raw_connection is conn2.raw_connection + assert await engine.scalar("select now()") + assert qsize(engine) == init_size - 1 + assert qsize(engine) == init_size - 1 + assert qsize(engine) == init_size + + async with engine.acquire(reuse=False) as conn1: + assert qsize(engine) == init_size - 1 + async with engine.acquire(reuse=True) as conn2: + assert qsize(engine) == init_size - 1 + assert conn1.raw_connection is conn2.raw_connection + assert qsize(engine) == init_size - 1 + assert qsize(engine) == init_size + + async with engine.acquire(reuse=True) as conn1: + assert qsize(engine) == init_size - 1 + async with engine.acquire(reuse=False) as conn2: + assert qsize(engine) == init_size - 2 + assert conn1.raw_connection is not conn2.raw_connection + assert qsize(engine) == init_size - 1 + assert qsize(engine) == init_size + + async with engine.acquire(reuse=False) as conn1: + assert qsize(engine) == init_size - 1 + async with engine.acquire(reuse=False) as conn2: + assert qsize(engine) == init_size - 2 + assert conn1.raw_connection is not conn2.raw_connection + assert qsize(engine) == init_size - 1 + assert qsize(engine) == init_size + + async with engine.acquire(reuse=False) as conn1: + assert qsize(engine) == init_size - 1 + async with engine.acquire(reuse=True) as conn2: + assert qsize(engine) == init_size - 1 + assert conn1.raw_connection is conn2.raw_connection + async with engine.acquire(reuse=False) as conn3: + assert qsize(engine) == init_size - 2 + assert conn1.raw_connection is not conn3.raw_connection + async with engine.acquire(reuse=True) as conn4: + assert qsize(engine) == init_size - 2 + assert conn3.raw_connection is conn4.raw_connection + assert qsize(engine) == init_size - 2 + assert qsize(engine) == init_size - 1 + assert qsize(engine) == init_size - 1 + assert qsize(engine) == init_size + + +async def test_compile(engine): + stmt, params = engine.compile(User.query.where(User.id == 3)) + assert params[0] == 3 + + +async def test_logging(mocker): + mocker.patch("logging.Logger._log") + sql = "SELECT NOW() AS test_logging" + + e = await create_engine(MYSQL_URL, echo=False) + await e.scalar(sql) + await e.close() + # noinspection PyProtectedMember,PyUnresolvedReferences + logging.Logger._log.assert_not_called() + + e = await create_engine(MYSQL_URL, echo=True) + await e.scalar(sql) + await e.close() + # 
noinspection PyProtectedMember,PyUnresolvedReferences + logging.Logger._log.assert_any_call(logging.INFO, sql, ()) + + +async def test_set_isolation_level(): + e = await create_engine(MYSQL_URL, isolation_level="non") + with pytest.raises(sa.exc.ArgumentError): + await e.acquire() + await e.close() + e = await create_engine(MYSQL_URL, isolation_level="READ_UNCOMMITTED") + async with e.acquire() as conn: + assert ( + await e.dialect.get_isolation_level(conn.raw_connection) + == "READ UNCOMMITTED" + ) + async with e.transaction(isolation="SERIALIZABLE") as tx: + assert ( + await e.dialect.get_isolation_level(tx.connection.raw_connection) + == "SERIALIZABLE" + ) + await e.close() + + +async def test_too_many_engine_args(): + with pytest.raises(TypeError): + await create_engine(MYSQL_URL, non_exist=None) + + +# noinspection PyUnusedLocal +async def test_scalar_return_none(bind): + assert await User.query.where(User.nickname == "nonexist").gino.scalar() is None + + +async def test_async_metadata(): + import gino + + db_ = await gino.Gino(MYSQL_URL) + assert isinstance((await db_.scalar("select now()")), datetime) + await db_.pop_bind().close() + with pytest.raises(UninitializedError): + db.bind.first() + + +# noinspection PyUnreachableCode +async def test_acquire_timeout(): + e = await create_engine(MYSQL_URL, minsize=1, maxsize=1) + async with e.acquire() as x: + with pytest.raises(asyncio.TimeoutError): + async with e.acquire(timeout=0.1) as y: + assert False, "Should not reach here" + + loop = asyncio.get_event_loop() + f1 = loop.create_future() + + async def first(): + async with e.acquire() as conn: + f1.set_result(None) + await asyncio.sleep(0.2) + # noinspection PyProtectedMember + return conn.raw_connection + + async def second(): + async with e.acquire(lazy=True) as conn: + conn = conn.execution_options(timeout=0.1) + with pytest.raises(asyncio.TimeoutError): + assert await conn.scalar("select 1") + + async def third(): + async with e.acquire(reuse=True, timeout=0.4) as conn: + # noinspection PyProtectedMember + return conn.raw_connection + + t1 = loop.create_task(first()) + await f1 + loop.create_task(second()) + t3 = loop.create_task(third()) + assert await t1 is await t3 + await e.close() + + +# noinspection PyProtectedMember +async def test_lazy(mocker): + engine = await create_engine(MYSQL_URL, minsize=1, maxsize=1) + init_size = qsize(engine) + async with engine.acquire(lazy=True): + assert qsize(engine) == init_size + assert len(engine._ctx.get()) == 1 + assert engine._ctx.get() is None + assert qsize(engine) == init_size + async with engine.acquire(lazy=True): + assert qsize(engine) == init_size + assert len(engine._ctx.get()) == 1 + assert await engine.scalar("select 1") + assert qsize(engine) == init_size - 1 + assert len(engine._ctx.get()) == 1 + assert engine._ctx.get() is None + assert qsize(engine) == init_size + + loop = asyncio.get_event_loop() + fut = loop.create_future() + + async def block(): + async with engine.acquire(): + fut.set_result(None) + await asyncio.sleep(0.3) + + blocker = loop.create_task(block()) + await fut + init_size_2 = qsize(engine) + ctx = engine.acquire(lazy=True) + conn = await ctx.__aenter__() + t1 = loop.create_task(conn.execution_options(timeout=0.1).scalar("select 1")) + t2 = loop.create_task(ctx.__aexit__(None, None, None)) + with pytest.raises(asyncio.TimeoutError): + await t1 + assert not await t2 + assert qsize(engine) == init_size_2 + await blocker + assert qsize(engine) == init_size + + fut = loop.create_future() + blocker = 
loop.create_task(block()) + await fut + init_size_2 = qsize(engine) + + async def acquire_failed(*args, **kwargs): + await asyncio.sleep(0.1) + raise ValueError() + + mocker.patch("aiomysql.pool.Pool.acquire", new=acquire_failed) + ctx = engine.acquire(lazy=True) + conn = await ctx.__aenter__() + t1 = loop.create_task(conn.scalar("select 1")) + t2 = loop.create_task(conn.release(permanent=False)) + with pytest.raises(ValueError): + await t1 + assert not await t2 + assert qsize(engine) == init_size_2 + + await conn.release(permanent=False) + assert qsize(engine) == init_size_2 + + await blocker + assert qsize(engine) == init_size + await engine.close() + + +async def test_release(engine): + init_size = qsize(engine) + async with engine.acquire() as conn: + assert await conn.scalar("select 8") == 8 + await conn.release(permanent=False) + assert await conn.scalar("select 8") == 8 + await conn.release(permanent=False) + with pytest.raises(ValueError, match="released permanently"): + await conn.scalar("select 8") + with pytest.raises(ValueError, match="already released"): + await conn.release() + + conn = await engine.acquire() + assert await conn.scalar("select 8") == 8 + await conn.release(permanent=False) + assert await conn.scalar("select 8") == 8 + await conn.release() + with pytest.raises(ValueError, match="released permanently"): + await conn.scalar("select 8") + + conn1 = await engine.acquire() + conn2 = await engine.acquire(reuse=True) + conn3 = await engine.acquire() + conn4 = await engine.acquire(reuse=True) + assert await conn1.scalar("select 8") == 8 + assert await conn2.scalar("select 8") == 8 + assert await conn3.scalar("select 8") == 8 + assert await conn4.scalar("select 8") == 8 + + await conn1.release(permanent=False) + assert await conn2.scalar("select 8") == 8 + + await conn2.release(permanent=False) + assert await conn2.scalar("select 8") == 8 + + await conn1.release() + with pytest.raises(ValueError, match="released permanently"): + await conn2.scalar("select 8") + assert await conn4.scalar("select 8") == 8 + + await conn4.release() + with pytest.raises(ValueError, match="released permanently"): + await conn4.scalar("select 8") + + assert await conn3.scalar("select 8") == 8 + await conn3.release(permanent=False) + assert await conn3.scalar("select 8") == 8 + assert init_size - 1 == qsize(engine) + await conn3.release(permanent=False) + assert init_size == qsize(engine) + await conn3.release() + assert init_size == qsize(engine) + + conn1 = await engine.acquire() + conn2 = await engine.acquire() + conn3 = await engine.acquire() + assert engine.current_connection is conn3 + await conn2.release() + assert engine.current_connection is conn3 + await conn1.release() + assert engine.current_connection is conn3 + await conn3.release() + assert engine.current_connection is None + assert init_size == qsize(engine) + + +async def test_ssl(): + import ssl + + ctx = ssl.create_default_context() + ctx.check_hostname = False + ctx.verify_mode = ssl.CERT_NONE + + e = await create_engine(MYSQL_URL, ssl=ctx) + await e.close() + + +async def test_issue_313(bind): + assert bind._ctx.get() is None + + async with db.acquire(): + pass + + assert bind._ctx.get() is None + + async def task(): + async with db.acquire(reuse=True): + await db.scalar("SELECT now()") + + await asyncio.gather(*[task() for _ in range(5)]) + + assert bind._ctx.get() is None + + async def task(): + async with db.transaction(): + await db.scalar("SELECT now()") + + await asyncio.gather(*[task() for _ in range(5)]) + + 
assert bind._ctx.get() is None + + +# TODO: abstract NullPool implementation +async def _test_null_pool(): + from gino.dialects.asyncpg import NullPool + + e = await create_engine(MYSQL_URL, pool_class=NullPool) + async with e.acquire() as conn: + raw_conn = conn.raw_connection + assert raw_conn.is_closed() + + e = await create_engine(MYSQL_URL) + async with e.acquire() as conn: + # noinspection PyProtectedMember + raw_conn = conn.raw_connection._con + assert not raw_conn.is_closed() + + +async def test_repr(): + # FIXME + # from gino.dialects.asyncpg import NullPool + # + # e = await create_engine(MYSQL_URL, pool_class=NullPool) + # assert 'cur=0' in repr(e) + # async with e.acquire(): + # assert 'cur=1' in repr(e) + # async with e.acquire(): + # assert 'cur=2' in repr(e) + # assert 'cur=1' in repr(e) + # assert 'cur=0' in repr(e) + # assert 'NullPool' in e.repr(color=True) + + e = await create_engine(MYSQL_URL) + assert 'cur=1 use=0' in repr(e) + async with e.acquire(): + assert 'cur=1 use=1' in repr(e) + async with e.acquire(): + assert 'cur=2 use=2' in repr(e) + assert 'cur=2 use=1' in repr(e) + assert 'cur=2 use=0' in repr(e) + assert 'aiomysql.pool.Pool' in e.repr(color=True) + await e.close() diff --git a/mysql_tests/test_executemany.py b/mysql_tests/test_executemany.py new file mode 100644 index 00000000..b56ea002 --- /dev/null +++ b/mysql_tests/test_executemany.py @@ -0,0 +1,107 @@ +import pytest + +from gino import MultipleResultsFound, NoResultFound +from .models import db, User + +pytestmark = pytest.mark.asyncio + + +# noinspection PyUnusedLocal +async def test_status(bind): + statement, params = db.compile( + User.insert(), [dict(name="1"), dict(name="2")]) + assert statement == ( + "INSERT INTO gino_users (name, props, type) " "VALUES (%s, %s, %s)") + assert params == (("1", '"{}"', "USER"), ("2", '"{}"', "USER")) + result = await User.insert().gino.status(dict(name="1"), dict(name="2")) + assert result is None + assert len(await User.query.gino.all()) == 2 + + +# noinspection PyUnusedLocal +async def test_all(bind): + result = await User.insert().gino.all(dict(name="1"), dict(name="2")) + assert result is None + rows = await User.query.gino.all() + assert len(rows) == 2 + assert set(u.nickname for u in rows) == {"1", "2"} + + result = await User.insert().gino.all(dict(name="3"), dict(name="4")) + assert result is None + rows = await User.query.gino.all() + assert len(rows) == 4 + assert set(u.nickname for u in rows) == {"1", "2", "3", "4"} + + +# noinspection PyUnusedLocal +async def test_first(bind): + result = await User.insert().gino.first(dict(name="1"), dict(name="2")) + assert result is None + rows = await User.query.gino.all() + assert len(await User.query.gino.all()) == 2 + assert set(u.nickname for u in rows) == {"1", "2"} + + result = await User.insert().gino.first(dict(name="3"), dict(name="4")) + assert result is None + rows = await User.query.gino.all() + assert len(rows) == 4 + assert set(u.nickname for u in rows) == {"1", "2", "3", "4"} + + +# noinspection PyUnusedLocal +async def test_one_or_none(bind): + row = await User.query.gino.one_or_none() + assert row is None + + await User.create(nickname="0") + row = await User.query.gino.one_or_none() + assert row.nickname == "0" + + result = ( + await User.insert() + .gino.one_or_none(dict(name="1"), dict(name="2")) + ) + assert result is None + rows = await User.query.gino.all() + assert len(await User.query.gino.all()) == 3 + assert set(u.nickname for u in rows) == {"0", "1", "2"} + + with 
pytest.raises(MultipleResultsFound): + row = await User.query.gino.one_or_none() + + +# noinspection PyUnusedLocal +async def test_one(bind): + with pytest.raises(NoResultFound): + row = await User.query.gino.one() + + await User.create(nickname="0") + row = await User.query.gino.one() + assert row.nickname == "0" + + with pytest.raises(NoResultFound): + await User.insert().gino.one(dict(name="1"), dict(name="2")) + rows = await User.query.gino.all() + assert len(await User.query.gino.all()) == 3 + assert set(u.nickname for u in rows) == {"0", "1", "2"} + + with pytest.raises(MultipleResultsFound): + row = await User.query.gino.one() + + +# noinspection PyUnusedLocal +async def test_scalar(bind): + result = ( + await User.insert() + .gino.scalar(dict(name="1"), dict(name="2")) + ) + assert result is None + rows = await User.query.gino.all() + assert len(await User.query.gino.all()) == 2 + assert set(u.nickname for u in rows) == {"1", "2"} + + result = await User.insert().gino.scalar(dict(name="3"), dict(name="4")) + assert result is None + rows = await User.query.gino.all() + assert len(rows) == 4 + assert set(u.nickname for u in rows) == {"1", "2", "3", "4"} diff --git a/mysql_tests/test_execution_options.py b/mysql_tests/test_execution_options.py new file mode 100644 index 00000000..469f5f60 --- /dev/null +++ b/mysql_tests/test_execution_options.py @@ -0,0 +1,77 @@ +import asyncio + +import pytest + +from .models import db, User, UserType + +pytestmark = pytest.mark.asyncio + + +async def test(bind): + await User.create(nickname="test") + assert isinstance(await User.query.gino.first(), User) + bind.update_execution_options(return_model=False) + assert not isinstance(await User.query.gino.first(), User) + async with db.acquire() as conn: + assert isinstance( + await conn.execution_options(return_model=True).first(User.query), User + ) + assert not isinstance( + await User.query.execution_options(return_model=False).gino.first(), User + ) + assert isinstance( + await User.query.execution_options(return_model=True).gino.first(), User + ) + assert not isinstance(await User.query.gino.first(), User) + bind.update_execution_options(return_model=True) + assert isinstance(await User.query.gino.first(), User) + + +# noinspection PyProtectedMember +async def test_compiled_first_not_found(bind): + async with bind.acquire() as conn: + with pytest.raises(LookupError, match="No such execution option"): + result = conn._execute("SELECT NOW()", (), {}) + result.context._compiled_first_opt("nonexist") + + +# noinspection PyUnusedLocal +async def test_query_ext(bind): + q = User.query + assert q.gino.query is q + + u = await User.create(nickname="test") + assert isinstance(await User.query.gino.first(), User) + + row = await User.query.gino.return_model(False).first() + assert not isinstance(row, User) + assert row == ( + u.id, + "test", + {"age": 18, "birthday": "1970-01-01T00:00:00.000000"}, + UserType.USER, + None, + ) + + row = await User.query.gino.model(None).first() + assert not isinstance(row, User) + assert row == ( + u.id, + "test", + {"age": 18, "birthday": "1970-01-01T00:00:00.000000"}, + UserType.USER, + None, + ) + + row = await db.select([User.id, User.nickname, User.type]).gino.first() + assert not isinstance(row, User) + assert row == (u.id, "test", UserType.USER) + + user = await db.select([User.id, User.nickname, User.type]).gino.model(User).first() + assert isinstance(user, User) + assert user.id is not None + assert user.nickname == "test" + assert user.type == UserType.USER + + with 
pytest.raises(asyncio.TimeoutError): + await db.select([db.func.SLEEP(1), User.id]).gino.timeout(0.1).status() diff --git a/mysql_tests/test_ext.py b/mysql_tests/test_ext.py new file mode 100644 index 00000000..44fd6b53 --- /dev/null +++ b/mysql_tests/test_ext.py @@ -0,0 +1,67 @@ +import collections +import importlib +import sys +import pytest + + +def installed(): + rv = 0 + for finder in sys.meta_path: + if type(finder).__name__ == "_GinoExtensionCompatFinder": + rv += 1 + return rv + + +def test_install(): + from gino import ext + + importlib.reload(ext) + + assert installed() == 1 + + ext._GinoExtensionCompatFinder().install() + assert installed() == 1 + + ext._GinoExtensionCompatFinder.uninstall() + assert not installed() + + ext._GinoExtensionCompatFinder().uninstall() + assert not installed() + + ext._GinoExtensionCompatFinder().install() + assert installed() == 1 + + ext._GinoExtensionCompatFinder().install() + assert installed() == 1 + + +def test_import(mocker): + from gino import ext + + importlib.reload(ext) + + EntryPoint = collections.namedtuple("EntryPoint", ["name", "value"]) + mocker.patch( + "gino.ext.entry_points", + new=lambda: { + "gino.extensions": [ + EntryPoint("demo", "tests.stub1"), + EntryPoint("demo2", "tests.stub2"), + ] + }, + ) + ext._GinoExtensionCompatFinder().install() + from gino.ext import demo + + assert sys.modules["tests.stub1"] is sys.modules["gino.ext.demo"] is demo + + from tests import stub2 + from gino.ext import demo2 + + assert sys.modules["tests.stub2"] is sys.modules["gino.ext.demo2"] is demo2 is stub2 + + +def test_import_error(): + with pytest.raises(ImportError, match="gino-nonexist"): + # noinspection PyUnresolvedReferences + from gino.ext import nonexist diff --git a/mysql_tests/test_iterate.py b/mysql_tests/test_iterate.py new file mode 100644 index 00000000..449de6d9 --- /dev/null +++ b/mysql_tests/test_iterate.py @@ -0,0 +1,90 @@ +from gino import UninitializedError +import pytest + +from .models import db, User + +pytestmark = pytest.mark.asyncio + + +@pytest.fixture +def names(sa_engine): + rv = {"11", "22", "33"} + sa_engine.execute(User.__table__.insert(), [dict(name=name) for name in rv]) + yield rv + sa_engine.execute("DELETE FROM gino_users") + + +# noinspection PyUnusedLocal,PyShadowingNames +async def test_bind(bind, names): + with pytest.raises(ValueError, match="No Connection in context"): + async for u in User.query.gino.iterate(): + assert False, "Should not reach here" + with pytest.raises(ValueError, match="No Connection in context"): + await User.query.gino.iterate() + with pytest.raises(ValueError, match="No Connection in context"): + await db.iterate(User.query) + + result = set() + async with bind.transaction(): + async for u in User.query.gino.iterate(): + result.add(u.nickname) + assert names == result + + result = set() + async with bind.transaction(): + async for u in db.iterate(User.query): + result.add(u.nickname) + assert names == result + + result = set() + async with bind.transaction(): + cursor = await User.query.gino.iterate() + result.add((await cursor.next()).nickname) + assert names != result + result.update([u.nickname for u in await cursor.many(1)]) + assert names != result + result.update([u.nickname for u in await cursor.many(2)]) + assert names == result + result.update([u.nickname for u in await cursor.many(2)]) + assert names == result + assert await cursor.next() is None + + with pytest.raises(ValueError, match="too many multiparams"): + async with bind.transaction(): + await db.iterate( + 
User.insert(), + [dict(nickname="444"), dict(nickname="555"), dict(nickname="666"),], + ) + + result = set() + async with bind.transaction(): + cursor = await User.query.gino.iterate() + await cursor.forward(1) + result.add((await cursor.next()).nickname) + assert names != result + result.update([u.nickname for u in await cursor.many(1)]) + assert names != result + result.update([u.nickname for u in await cursor.many(2)]) + assert names != result + assert await cursor.next() is None + + +# noinspection PyUnusedLocal,PyShadowingNames +async def test_basic(engine, names): + result = set() + async with engine.transaction() as tx: + with pytest.raises(UninitializedError): + await db.iterate(User.query) + result = set() + async for u in tx.connection.iterate(User.query): + result.add(u.nickname) + async for u in tx.connection.execution_options(timeout=1).iterate(User.query): + result.add(u.nickname) + assert names == result + + result = set() + cursor = await tx.connection.iterate(User.query) + result.update([u.nickname for u in await cursor.many(2)]) + assert names != result + result.update([u.nickname for u in await cursor.many(2)]) + assert names == result diff --git a/mysql_tests/test_json.py b/mysql_tests/test_json.py new file mode 100644 index 00000000..29a217dd --- /dev/null +++ b/mysql_tests/test_json.py @@ -0,0 +1,253 @@ +import pytest +from datetime import datetime, timedelta + +from gino.exceptions import UnknownJSONPropertyError + +from .models import db, User, UserType + +pytestmark = pytest.mark.asyncio + + +async def test_in_memory(): + u = User() + assert u.age == 18 + u.age += 10 + assert u.age == 28 + assert u.balance == 0 + assert isinstance(u.balance, float) + + +# noinspection PyUnusedLocal +async def test_crud(bind): + from gino.json_support import DATETIME_FORMAT + + now = datetime.utcnow() + now_str = now.strftime(DATETIME_FORMAT) + u = await User.create(nickname="fantix", birthday=now) + u.age += 1 + assert await u.query.gino.model(None).first() == ( + 1, + "fantix", + {"age": 18, "birthday": now_str}, + UserType.USER, + None, + ) + + u = await User.get(u.id) + assert u.nickname == "fantix" + assert u.birthday == now + assert u.age == 18 + assert u.balance == 0 + assert isinstance(u.balance, float) + assert await db.select([User.birthday]).where(User.id == u.id).gino.scalar() == now + + # In-memory update, not applying + u.update(birthday=now - timedelta(days=3650)) + + # Update two JSON fields, one using expression + await u.update(age=User.age - 2, balance=100.85).apply() + + assert u.birthday == now - timedelta(days=3650) + assert u.age == 16 + assert u.balance == 100 + assert isinstance(u.balance, float) + assert await u.query.gino.model(None).first() == ( + 1, + "fantix", + dict(age=16, balance=100, birthday=now_str), + UserType.USER, + None, + ) + assert await db.select([User.realname]).where(User.id == u.id).gino.scalar() is None + + # Reload and test updating both JSON and regular property + u = await User.get(u.id) + await u.update( + age=User.age - 2, balance=200.15, realname="daisy", nickname="daisy.nick" + ).apply() + assert await u.query.gino.model(None).first() == ( + 1, + "daisy.nick", + dict(age=14, balance=200, realname="daisy", birthday=now_str), + UserType.USER, + None, + ) + assert u.to_dict() == dict( + age=14, + balance=200.0, + birthday=now, + id=1, + nickname="daisy.nick", + realname="daisy", + type=UserType.USER, + team_id=None, + ) + + # Deleting property doesn't affect database + assert u.balance == 200 + u.balance = 300 + assert u.balance == 
300 + del u.balance + assert u.balance == 0 + assert await db.select([User.balance]).where(User.id == u.id).gino.scalar() == 200 + await u.update(age=22).apply() + assert u.balance == 0 + assert await db.select([User.balance]).where(User.id == u.id).gino.scalar() == 200 + await u.update(balance=None).apply() + assert u.balance == 0 + assert await db.select([User.balance]).where(User.id == u.id).gino.scalar() is None + + +# noinspection PyUnusedLocal +async def test_reload(bind): + u = await User.create() + await u.update(realname=db.cast("888", db.Unicode)).apply() + assert u.realname == "888" + await u.update(profile=None).apply() + assert u.realname == "888" + User.__dict__["realname"].reload(u) + assert u.realname is None + + +# noinspection PyUnusedLocal +async def test_properties(bind): + from gino.dialects.aiomysql import JSON + + class PropsTest(db.Model): + __tablename__ = "props_test" + profile = db.Column(JSON(), nullable=False, default="{}") + + raw = db.JSONProperty() + bool = db.BooleanProperty() + obj = db.ObjectProperty() + arr = db.ArrayProperty() + + await PropsTest.gino.create() + try: + t = await PropsTest.create( + raw=dict(a=[1, 2]), bool=True, obj=dict(x=1, y=2), arr=[3, 4, 5, 6], + ) + assert t.obj["x"] == 1 + assert t.arr[-1] == 6 + assert await db.select( + [PropsTest.profile, PropsTest.raw, PropsTest.bool] + ).gino.first() == ( + { + "arr": [3, 4, 5, 6], + "obj": {"x": 1, "y": 2}, + "raw": {"a": [1, 2]}, + "bool": True, + }, + dict(a=[1, 2]), + True, + ) + t.obj = dict(x=10, y=20) + assert t.obj["x"] == 10 + t.arr = [4, 5, 6, 7] + assert t.arr[-1] == 7 + finally: + await PropsTest.gino.drop() + + +# noinspection PyUnusedLocal +async def test_unknown_properties(bind): + from gino.dialects.aiomysql import JSON + + class PropsTest1(db.Model): + __tablename__ = "props_test1" + profile = db.Column(JSON(), nullable=False, default="{}") + bool = db.BooleanProperty() + + await PropsTest1.gino.create() + try: + # bool1 is not defined in the model + t = await PropsTest1.create(profile=dict(bool1=True)) + with pytest.raises(UnknownJSONPropertyError, match=r"bool1.*profile"): + t.to_dict() + finally: + await PropsTest1.gino.drop() + + +async def test_property_in_profile_and_attribute_collide(bind): + from gino.dialects.aiomysql import JSON + + class PropsTest2(db.Model): + __tablename__ = "props_test2" + profile = db.Column(JSON(), nullable=False, default="{}") + bool_profile = db.BooleanProperty() + bool_attr = db.Column(db.Boolean) + + await PropsTest2.gino.create() + try: + await PropsTest2.create( + profile={"bool_attr": False, "bool_profile": True}, bool_attr=True + ) + # bool_attr is defined in the model + # bool_profile is defined as json property + t2 = await PropsTest2.query.gino.first() + + assert t2.bool_attr is True + with pytest.raises(UnknownJSONPropertyError, match=r"bool_attr"): + assert t2.bool_profile is True + finally: + await PropsTest2.gino.drop() + + +async def test_no_profile(): + with pytest.raises(AttributeError, match=r"JSON\[B\] column"): + # noinspection PyUnusedLocal + class Test(db.Model): + __tablename__ = "tests_no_profile" + + id = db.Column(db.BigInteger(), primary_key=True) + age = db.IntegerProperty(default=18) + + +async def test_t291_t402(bind): + from gino.dialects.aiomysql import JSON + + class CustomJSON(db.TypeDecorator): + impl = JSON + + def process_result_value(self, *_): + return 123 + + class PropsTest(db.Model): + __tablename__ = "props_test_291" + profile = db.Column(JSON(), nullable=False, default={}) + profile1 = 
db.Column(JSON(), nullable=False, default={}) + profile2 = db.Column(CustomJSON(), nullable=False, default={}) + + bool = db.BooleanProperty() + bool1 = db.BooleanProperty(prop_name="profile1") + + await PropsTest.gino.create() + try: + await PropsTest.create(bool=True, bool1=True) + profile1 = await bind.scalar("SELECT profile1 FROM props_test_291") + assert isinstance(profile1, dict) + profile2 = await bind.scalar("SELECT profile2 FROM props_test_291") + assert isinstance(profile2, dict) + custom_profile2 = await bind.scalar(PropsTest.select("profile2")) + assert isinstance(custom_profile2, int) + assert custom_profile2 == 123 + finally: + await PropsTest.gino.drop() + + +async def test_json_path(bind): + from gino.dialects.aiomysql import JSON + + class PathTest(db.Model): + __tablename__ = "path_test_json_path" + data = db.Column(JSON()) + + await PathTest.gino.create() + try: + t1 = await PathTest.create(data=dict(a=dict(b="c"))) + t2 = await PathTest.query.where( + PathTest.data[("a", "b")] == "c" + ).gino.first() + assert t1.data == t2.data + finally: + await PathTest.gino.drop() diff --git a/mysql_tests/test_loader.py b/mysql_tests/test_loader.py new file mode 100644 index 00000000..68de6d20 --- /dev/null +++ b/mysql_tests/test_loader.py @@ -0,0 +1,402 @@ +import random +from datetime import datetime + +import pytest +from async_generator import yield_, async_generator + +from gino.loader import AliasLoader +from sqlalchemy import select +from sqlalchemy.sql.functions import count +from .models import ( + db, + User, + Team, + TeamWithDefaultCompany, + TeamWithoutMembersSetter, + Company, + CompanyWithoutTeamsSetter, +) + +pytestmark = pytest.mark.asyncio + + +@pytest.fixture +@async_generator +async def user(bind): + c = await Company.create() + t1 = await Team.create(company_id=c.id) + t2 = await Team.create(company_id=c.id, parent_id=t1.id) + t3 = await Team.create(company_id=c.id, parent_id=t1.id) + u = await User.create(team_id=t2.id) + u.team = t2 + t2.parent = t1 + t2.company = c + t1.company = c + await yield_(u) + await User.delete.gino.status() + await Team.delete.gino.status() + await Company.delete.gino.status() + + +async def test_model_alternative(user): + u = await User.query.gino.load(User).first() + assert isinstance(u, User) + assert u.id == user.id + assert u.nickname == user.nickname + + +async def test_scalar(user): + name = await User.query.gino.load(User.nickname).first() + assert user.nickname == name + + uid, name = await User.query.gino.load((User.id, User.nickname)).first() + assert user.id == uid + assert user.nickname == name + + +async def test_one_or_none(user): + name = await User.query.gino.load(User.nickname).one_or_none() + assert user.nickname == name + + uid, name = await (User.query.gino.load((User.id, User.nickname)).one_or_none()) + assert user.id == uid + assert user.nickname == name + + +async def test_one(user): + name = await User.query.gino.load(User.nickname).one() + assert user.nickname == name + + uid, name = await User.query.gino.load((User.id, User.nickname)).one() + assert user.id == uid + assert user.nickname == name + + +async def test_model_load(user): + u = await User.query.gino.load(User.load("nickname", User.team_id)).first() + assert isinstance(u, User) + assert u.id is None + assert u.nickname == user.nickname + assert u.team_id == user.team.id + + with pytest.raises(TypeError): + await User.query.gino.load(User.load(123)).first() + + with pytest.raises(AttributeError): + await 
User.query.gino.load(User.load(Team.id)).first() + + +async def test_216_model_load_passive_partial(user): + u = await db.select([User.nickname]).gino.model(User).first() + assert isinstance(u, User) + assert u.id is None + assert u.nickname == user.nickname + + +async def test_load_relationship(user): + u = await User.outerjoin(Team).select().gino.load(User.load(team=Team)).first() + assert isinstance(u, User) + assert u.id == user.id + assert u.nickname == user.nickname + assert isinstance(u.team, Team) + assert u.team.id == user.team.id + assert u.team.name == user.team.name + + +@pytest.mark.parametrize("team_cls", [Team, TeamWithDefaultCompany]) +async def test_load_nested(user, team_cls): + for u in ( + await User.outerjoin(team_cls) + .outerjoin(Company) + .select() + .gino.load(User.load(team=team_cls.load(company=Company))) + .first(), + await User.load(team=team_cls.load(company=Company)).gino.first(), + await User.load( + team=team_cls.load(company=Company.on(team_cls.company_id == Company.id)) + ).gino.first(), + await User.load( + team=team_cls.load(company=Company).on(User.team_id == team_cls.id) + ).gino.first(), + await User.load( + team=team_cls.on(User.team_id == team_cls.id).load(company=Company) + ).gino.first(), + ): + assert isinstance(u, User) + assert u.id == user.id + assert u.nickname == user.nickname + assert isinstance(u.team, team_cls) + assert u.team.id == user.team.id + assert u.team.name == user.team.name + assert isinstance(u.team.company, Company) + assert u.team.company.id == user.team.company.id + assert u.team.company.name == user.team.company.name + + +async def test_func(user): + def loader(row, context): + rv = User(id=row[User.id], nickname=row[User.nickname]) + rv.team = Team(id=row[Team.id], name=row[Team.name]) + rv.team.company = Company(id=row[Company.id], name=row[Company.name]) + return rv + + u = await User.outerjoin(Team).outerjoin(Company).select().gino.load(loader).first() + assert isinstance(u, User) + assert u.id == user.id + assert u.nickname == user.nickname + assert isinstance(u.team, Team) + assert u.team.id == user.team.id + assert u.team.name == user.team.name + assert isinstance(u.team.company, Company) + assert u.team.company.id == user.team.company.id + assert u.team.company.name == user.team.company.name + + +async def test_adjacency_list(user): + group = Team.alias() + + with pytest.raises(AttributeError): + group.non_exist() + + # noinspection PyUnusedLocal + def loader(row, context): + rv = User(id=row[User.id], nickname=row[User.nickname]) + rv.team = Team(id=row[Team.id], name=row[Team.name]) + rv.team.parent = Team(id=row[group.id], name=row[group.name]) + return rv + + for exp in ( + loader, + User.load(team=Team.load(parent=group)), + User.load(team=Team.load(parent=group.load("id", "name"))), + User.load(team=Team.load(parent=group.load())), + ): + u = ( + await User.outerjoin(Team) + .outerjoin(group, Team.parent_id == group.id) + .select() + .gino.load(exp) + .first() + ) + + assert isinstance(u, User) + assert u.id == user.id + assert u.nickname == user.nickname + assert isinstance(u.team, Team) + assert u.team.id == user.team.id + assert u.team.name == user.team.name + assert isinstance(u.team.parent, Team) + assert u.team.parent.id == user.team.parent.id + assert u.team.parent.name == user.team.parent.name + + +async def test_alias_distinct(user): + group = Team.alias() + group_company = Company.alias() + t1, t2, t3 = ( + await Team.outerjoin(Company) + .outerjoin(group, Team.parent_id == group.id) + 
.outerjoin(group_company, group.company_id == group_company.id) + .select() + .order_by(Team.id) + .gino.load( + Team.distinct(Team.id).load( + company=Company.distinct(Company.id), + parent=group.distinct(group.id).load( + company=group_company.distinct(group_company.id) + ), + ) + ) + .all() + ) + assert t2.parent.name == t1.name + assert t1.company is t2.company + assert t2.parent.company is t3.parent.company + + +async def test_alias_loader_columns(user): + user_alias = User.alias() + base_query = user_alias.outerjoin(Team).select() + + query = base_query.execution_options(loader=AliasLoader(user_alias, "id")) + u = await query.gino.first() + assert u.id is not None + + +async def test_multiple_models_in_one_query(bind): + for _ in range(3): + await User.create() + + ua1 = User.alias() + ua2 = User.alias() + join_query = select([ua1, ua2]).where(ua1.id < ua2.id) + result = await join_query.gino.load((ua1.load("id"), ua2.load("id"))).all() + assert len(result) == 3 + for u1, u2 in result: + assert u1.id is not None + assert u2.id is not None + assert u1.id < u2.id + + +async def test_loader_with_aggregation(user): + count_col = count().label("count") + user_count = select([User.team_id, count_col]).group_by(User.team_id).alias() + query = Team.outerjoin(user_count).select() + result = await query.gino.load( + (Team.id, Team.name, user_count.columns.team_id, count_col) + ).all() + assert len(result) == 3 + # team 1/3 doesn't have users, team 2 has 1 user + # third and forth columns are None for team 1/3 + for team_id, team_name, user_team_id, user_count in result: + if team_id == user.team_id: + assert team_name == user.team.name + assert user_team_id == user.team_id + assert user_count == 1 + else: + assert team_id is not None + assert team_name is not None + assert user_team_id is None + assert user_count is None + + +async def test_adjacency_list_query_builder(user): + group = Team.alias() + u = await User.load( + team=Team.load(parent=group.on(Team.parent_id == group.id)) + ).gino.first() + + assert isinstance(u, User) + assert u.id == user.id + assert u.nickname == user.nickname + assert isinstance(u.team, Team) + assert u.team.id == user.team.id + assert u.team.name == user.team.name + assert isinstance(u.team.parent, Team) + assert u.team.parent.id == user.team.parent.id + assert u.team.parent.name == user.team.parent.name + + +async def test_literal(user): + sample = tuple(random.random() for _ in range(5)) + now = db.Column("time", db.DateTime()) + row = await db.first( + db.text("SELECT UTC_TIMESTAMP") + .columns(now) + .gino.load(sample + (lambda r, c: datetime.utcnow(), now)) + .query + ) + assert row[:5] == sample + assert isinstance(row[-2], datetime) + assert isinstance(row[-1], datetime) + assert row[-1] <= row[-2] + + +@pytest.mark.parametrize( + ["team_cls", "company_cls"], + [(Team, Company), (TeamWithoutMembersSetter, CompanyWithoutTeamsSetter)], +) +async def test_load_one_to_many(user, team_cls, company_cls): + # noinspection PyListCreation + uids = [user.id] + uids.append((await User.create(nickname="1", team_id=user.team.id)).id) + uids.append((await User.create(nickname="1", team_id=user.team.id)).id) + uids.append((await User.create(nickname="2", team_id=user.team.parent.id)).id) + query = User.outerjoin(team_cls).outerjoin(company_cls).select() + companies = await query.gino.load( + company_cls.distinct(company_cls.id).load( + add_team=team_cls.load(add_member=User).distinct(team_cls.id) + ) + ).all() + assert len(companies) == 1 + company = companies[0] + 
assert isinstance(company, Company) + assert company.id == user.team.company_id + assert company.name == user.team.company.name + assert len(company.teams) == 2 + for team in company.teams: + if team.id == user.team.id: + assert len(team.members) == 3 + for u in team.members: + if u.nickname == user.nickname: + assert isinstance(u, User) + assert u.id == user.id + uids.remove(u.id) + if u.nickname in {"1", "2"}: + uids.remove(u.id) + else: + assert len(team.members) == 1 + uids.remove(list(team.members)[0].id) + assert uids == [] + + # test distinct many-to-one + query = User.outerjoin(team_cls).select().where(team_cls.id == user.team.id) + users = await query.gino.load(User.load(team=team_cls.distinct(team_cls.id))).all() + assert len(users) == 3 + assert users[0].team is users[1].team + assert users[0].team is users[2].team + + +async def test_distinct_none(bind): + u = await User.create() + + query = User.outerjoin(Team).select().where(User.id == u.id) + loader = User.load(team=Team) + + u = await query.gino.load(loader).first() + assert not hasattr(u, "team") + + u = await User.load(team=Team).query.where(User.id == u.id).gino.first() + assert not hasattr(u, "team") + + query = User.outerjoin(Team).select().where(User.id == u.id) + loader = User.load(team=Team.distinct(Team.id)) + + u = await query.gino.load(loader).first() + assert not hasattr(u, "team") + + +async def test_tuple_loader_279(user): + from gino.loader import TupleLoader + + query = db.select([User, Team]) + async with db.transaction(): + async for row in query.gino.load((User, Team)).iterate(): + assert len(row) == 2 + async for row in query.gino.load(TupleLoader((User, Team))).iterate(): + assert len(row) == 2 + + +async def test_none_as_none_281(user): + import gino + + if gino.__version__ < "0.9": + query = Team.outerjoin(User).select() + loader = Team, User.none_as_none() + assert any(row[1] is None for row in await query.gino.load(loader).all()) + + loader = Team.distinct(Team.id).load(add_member=User.none_as_none()) + assert any(not team.members for team in await query.gino.load(loader).all()) + + if gino.__version__ >= "0.8.0": + query = Team.outerjoin(User).select() + loader = Team, User + assert any(row[1] is None for row in await query.gino.load(loader).all()) + + loader = Team.distinct(Team.id).load(add_member=User) + assert any(not team.members for team in await query.gino.load(loader).all()) + + +async def test_model_in_query(user): + query = select([User], from_obj=User.outerjoin(Team)) + query = query.where(Team.company_id == user.team.company.id) + + query = query.alias("users") + User1 = User.in_query(query) + + query = query.outerjoin(Team).outerjoin(Company).select() + loader = User1.distinct(User1.id).load() + users = await query.gino.load(loader).all() + assert users[0] != user + assert users[0].id == user.id + assert users[0].nickname == user.nickname diff --git a/mysql_tests/test_schema.py b/mysql_tests/test_schema.py new file mode 100644 index 00000000..5b0196e4 --- /dev/null +++ b/mysql_tests/test_schema.py @@ -0,0 +1,63 @@ +from enum import Enum + +import pytest + +import gino +from gino.dialects.aiomysql import AsyncEnum + +pytestmark = pytest.mark.asyncio +db = gino.Gino() + + +class MyEnum(Enum): + ONE = "one" + TWO = "two" + + +class Blog(db.Model): + __tablename__ = "s_blog" + + id = db.Column(db.BigInteger(), primary_key=True) + title = db.Column(db.Unicode(255), index=True, comment="Title Comment") + visits = db.Column(db.BigInteger(), default=0) + comment_id = 
db.Column(db.ForeignKey("s_comment.id")) + number = db.Column(db.Enum(MyEnum), nullable=False, default=MyEnum.TWO) + number2 = db.Column(AsyncEnum(MyEnum), nullable=False, default=MyEnum.TWO) + + +class Comment(db.Model): + __tablename__ = "s_comment" + + id = db.Column(db.BigInteger(), primary_key=True) + blog_id = db.Column(db.ForeignKey("s_blog.id", name="blog_id_fk")) + + +blog_seq = db.Sequence("blog_seq", metadata=db, schema="schema_test") + + +async def test(engine, define=True): + async with engine.acquire() as conn: + assert not await engine.dialect.has_table(conn, "non_exist") + Blog.__table__.comment = "Blog Comment" + db.bind = engine + await db.gino.create_all() + await Blog.number.type.create_async(engine, checkfirst=True) + await Blog.number2.type.create_async(engine, checkfirst=True) + await db.gino.create_all(tables=[Blog.__table__], checkfirst=True) + await blog_seq.gino.create(checkfirst=True) + await Blog.__table__.gino.create(checkfirst=True) + await db.gino.drop_all() + await db.gino.drop_all(tables=[Blog.__table__], checkfirst=True) + await Blog.__table__.gino.drop(checkfirst=True) + await blog_seq.gino.drop(checkfirst=True) + + if define: + + class Comment2(db.Model): + __tablename__ = "s_comment_2" + + id = db.Column(db.BigInteger(), primary_key=True) + blog_id = db.Column(db.ForeignKey("s_blog.id")) + + await db.gino.create_all() + await db.gino.drop_all() diff --git a/mysql_tests/test_transaction.py b/mysql_tests/test_transaction.py new file mode 100644 index 00000000..452ad521 --- /dev/null +++ b/mysql_tests/test_transaction.py @@ -0,0 +1,265 @@ +import pytest + +from .models import db, User, qsize + +pytestmark = pytest.mark.asyncio + + +async def _init(bind): + from .test_crud import test_create + + u = await test_create(bind) + + def get_name(): + return User.select("nickname").where(User.id == u.id).gino.scalar() + + return u, get_name + + +async def test_connection_ctx(bind, mocker): + init_size = qsize(bind) + u, get_name = await _init(bind) + + assert await get_name() != "commit" + + async with bind.acquire() as conn: + async with conn.transaction(): + await u.update(nickname="commit").apply() + assert await get_name() == "commit" + + with pytest.raises(ZeroDivisionError): + async with bind.acquire() as conn: + async with conn.transaction(): + await u.update(nickname="rollback").apply() + assert await get_name() == "rollback" + raise ZeroDivisionError + + assert await get_name() == "commit" + + async with bind.acquire() as conn: + tx = await conn.transaction().__aenter__() + await u.update(nickname="rollback").apply() + assert await get_name() == "rollback" + mocker.patch("aiomysql.connection.Connection.commit").side_effect = IndexError + with pytest.raises(IndexError): + await tx.__aexit__(None, None, None) + # clean up, and to simulate commit failed + mocker.stopall() + await tx._tx.rollback() + assert await get_name() == "commit" + assert await get_name() == "commit" + + assert init_size == qsize(bind) + + +async def test_connection_await(bind): + init_size = qsize(bind) + u, get_name = await _init(bind) + + assert await get_name() != "commit" + + async with bind.acquire() as conn: + tx = await conn.transaction() + await u.update(nickname="commit").apply() + await tx.commit() + assert await get_name() == "commit" + + async with bind.acquire() as conn: + tx = await conn.transaction() + await u.update(nickname="rollback").apply() + assert await get_name() == "rollback" + await tx.rollback() + + assert await get_name() == "commit" + + # Neither commit nor 
rollback, should rollback + async with bind.acquire() as conn: + await conn.transaction() + await u.update(nickname="rollback").apply() + assert await get_name() == "rollback" + + assert await get_name() == "commit" + + assert init_size == qsize(bind) + + +async def test_engine(bind): + init_size = qsize(bind) + u, get_name = await _init(bind) + + assert await get_name() != "commit" + + async with bind.transaction(): + await u.update(nickname="commit").apply() + assert await get_name() == "commit" + + with pytest.raises(ZeroDivisionError): + async with bind.transaction(): + await u.update(nickname="rollback").apply() + raise ZeroDivisionError + assert await get_name() == "commit" + assert init_size == qsize(bind) + + +async def test_begin_failed(bind, mocker): + from aiomysql.connection import Connection + + init_size = qsize(bind) + mocker.patch("aiomysql.connection.Connection.begin") + Connection.begin.side_effect = ZeroDivisionError + with pytest.raises(ZeroDivisionError): + async with bind.transaction(): + pass # pragma: no cover + assert init_size == qsize(bind) + + +async def test_commit_failed(bind, mocker): + from aiomysql.connection import Connection + + init_size = qsize(bind) + mocker.patch("aiomysql.connection.Connection.begin") + # noinspection PyUnresolvedReferences,PyProtectedMember + Connection.begin.side_effect = ZeroDivisionError + with pytest.raises(ZeroDivisionError): + async with bind.transaction(): + pass + assert init_size == qsize(bind) + + +async def test_reuse(bind): + from aiomysql.connection import Connection + + init_size = qsize(bind) + async with db.acquire() as conn: + async with db.transaction() as tx: + assert tx.connection.raw_connection is conn.raw_connection + assert isinstance(tx.raw_transaction, Connection) + async with db.transaction() as tx2: + assert tx2.connection.raw_connection is conn.raw_connection + async with db.transaction(reuse=False) as tx2: + assert tx2.connection.raw_connection is not conn.raw_connection + async with db.transaction(reuse=False) as tx: + assert tx.connection.raw_connection is not conn.raw_connection + async with db.transaction() as tx2: + assert tx2.connection.raw_connection is tx.connection.raw_connection + async with db.transaction(reuse=False) as tx2: + assert tx2.connection.raw_connection is not conn.raw_connection + assert tx2.connection.raw_connection is not tx.connection.raw_connection + with pytest.raises(ValueError, match="already released"): + await conn.release() + assert init_size == qsize(bind) + + +async def test_nested(bind): + init_size = qsize(bind) + u, get_name = await _init(bind) + + name = await get_name() + assert u.nickname == name + + async with bind.transaction(): + await u.update(nickname="first").apply() + async with bind.transaction(): + pass + + assert init_size == qsize(bind) + + +# noinspection PyUnreachableCode,PyUnusedLocal +async def test_early_end(bind): + init_size = qsize(bind) + u, get_name = await _init(bind) + + assert await get_name() != "ininin" + + async with bind.transaction() as tx: + async with bind.transaction(): + async with bind.transaction(): + await u.update(nickname="ininin").apply() + tx.raise_commit() + assert False, "Should not reach here" + assert False, "Should not reach here" + assert False, "Should not reach here" + + assert await get_name() == "ininin" + assert init_size == qsize(bind) + + async with bind.transaction() as tx: + async with bind.transaction(): + async with bind.transaction(): + await u.update(nickname="nonono").apply() + assert await get_name() == 
"nonono" + tx.raise_rollback() + assert False, "Should not reach here" + assert False, "Should not reach here" + assert False, "Should not reach here" + + assert await get_name() == "ininin" + assert init_size == qsize(bind) + + reached = 0 + + async with bind.transaction(): + async with bind.transaction() as tx: + async with bind.transaction(): + await u.update(nickname="nonono").apply() + assert await get_name() == "nonono" + tx.raise_rollback() + assert False, "Should not reach here" + assert False, "Should not reach here" + reached += 1 + assert await get_name() == "ininin" + + assert await get_name() == "ininin" + assert init_size == qsize(bind) + assert reached == 1 + + async with bind.transaction(): + async with bind.transaction() as tx: + async with bind.transaction(): + await u.update(nickname="nonono").apply() + assert await get_name() == "nonono" + tx.raise_commit() + assert False, "Should not reach here" + assert False, "Should not reach here" + reached += 1 + assert await get_name() == "nonono" + + assert await get_name() == "nonono" + assert init_size == qsize(bind) + assert reached == 2 + + +# noinspection PyUnreachableCode +async def test_end_raises_in_with(engine): + async with engine.transaction() as tx: + with pytest.raises(AssertionError, match="Illegal in managed mode"): + await tx.commit() + await tx.raise_commit() + assert False, "Should not reach here" + + async with engine.transaction() as tx: + with pytest.raises(AssertionError, match="Illegal in managed mode"): + await tx.rollback() + await tx.raise_rollback() + assert False, "Should not reach here" + + +async def test_base_exception(engine): + async with engine.transaction() as tx: + # noinspection PyBroadException + try: + await tx.raise_commit() + except Exception: + assert False, "Should not reach here" + assert False, "Should not reach here" + + +async def test_no_rollback_on_commit_fail(engine, mocker): + mocker.patch("aiomysql.connection.Connection.commit").side_effect = IndexError + async with engine.acquire() as conn: + tx = await conn.transaction().__aenter__() + rollback = mocker.patch.object(tx._tx, "rollback") + with pytest.raises(IndexError): + await tx.__aexit__(None, None, None) + assert not rollback.called diff --git a/poetry.lock b/poetry.lock index 4caba9b0..3044dbc2 100644 --- a/poetry.lock +++ b/poetry.lock @@ -34,6 +34,20 @@ version = ">=3.6.5" [package.extras] speedups = ["aiodns", "brotlipy", "cchardet"] +[[package]] +category = "main" +description = "MySQL driver for asyncio." +name = "aiomysql" +optional = false +python-versions = "*" +version = "0.0.20" + +[package.dependencies] +PyMySQL = ">=0.9,<=0.9.2" + +[package.extras] +sa = ["sqlalchemy (>=1.0)"] + [[package]] category = "dev" description = "A configurable sidebar-enabled Sphinx theme" @@ -159,7 +173,18 @@ description = "Python package for providing Mozilla's CA Bundle." name = "certifi" optional = false python-versions = "*" -version = "2020.4.5.2" +version = "2020.6.20" + +[[package]] +category = "main" +description = "Foreign Function Interface for Python calling C code." +name = "cffi" +optional = false +python-versions = "*" +version = "1.14.0" + +[package.dependencies] +pycparser = "*" [[package]] category = "main" @@ -209,6 +234,25 @@ version = "5.1" [package.extras] toml = ["toml"] +[[package]] +category = "main" +description = "cryptography is a package which provides cryptographic recipes and primitives to Python developers." 
+name = "cryptography" +optional = false +python-versions = ">=2.7,!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*" +version = "2.9.2" + +[package.dependencies] +cffi = ">=1.8,<1.11.3 || >1.11.3" +six = ">=1.4.1" + +[package.extras] +docs = ["sphinx (>=1.6.5,<1.8.0 || >1.8.0)", "sphinx-rtd-theme"] +docstest = ["doc8", "pyenchant (>=1.6.11)", "twine (>=1.12.0)", "sphinxcontrib-spelling (>=4.0.1)"] +idna = ["idna (>=2.1)"] +pep8test = ["flake8", "flake8-import-order", "pep8-naming"] +test = ["pytest (>=3.6.0,<3.9.0 || >3.9.0,<3.9.1 || >3.9.1,<3.9.2 || >3.9.2)", "pretend", "iso8601", "pytz", "hypothesis (>=1.11.4,<3.79.2 || >3.79.2)"] + [[package]] category = "dev" description = "Docutils -- Python Documentation Utilities" @@ -320,7 +364,7 @@ marker = "python_version >= \"3.6\" and python_version < \"4.0\"" name = "hstspreload" optional = true python-versions = ">=3.6" -version = "2020.6.9" +version = "2020.6.16" [[package]] category = "main" @@ -525,6 +569,14 @@ optional = false python-versions = "*" version = "0.4.3" +[[package]] +category = "main" +description = "Python interface to MySQL" +name = "mysqlclient" +optional = false +python-versions = "*" +version = "1.4.6" + [[package]] category = "dev" description = "Core utilities for Python packages" @@ -621,7 +673,15 @@ description = "library with cross-python path, ini-parsing, io, code, log facili name = "py" optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" -version = "1.8.1" +version = "1.8.2" + +[[package]] +category = "main" +description = "C parser in Python" +name = "pycparser" +optional = false +python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" +version = "2.20" [[package]] category = "dev" @@ -631,6 +691,17 @@ optional = false python-versions = ">=3.5" version = "2.6.1" +[[package]] +category = "main" +description = "Pure Python MySQL Driver" +name = "pymysql" +optional = false +python-versions = "*" +version = "0.9.2" + +[package.dependencies] +cryptography = "*" + [[package]] category = "dev" description = "Python parsing module" @@ -769,7 +840,7 @@ description = "Python HTTP for Humans." 
name = "requests" optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*" -version = "2.23.0" +version = "2.24.0" [package.dependencies] certifi = ">=2017.4.17" @@ -818,7 +889,7 @@ docs = ["sphinx (>=2.1.2)", "sphinx-rtd-theme", "recommonmark (>=0.5.0)", "docut test = ["pytest (5.2.1)", "multidict (>=4.0,<5.0)", "gunicorn", "pytest-cov", "httpcore (0.3.0)", "beautifulsoup4", "pytest-sanic", "pytest-sugar", "pytest-benchmark", "uvloop (>=0.5.3)", "ujson (>=1.35)"] [[package]] -category = "dev" +category = "main" description = "Python 2 and 3 compatibility utilities" name = "six" optional = false @@ -1210,7 +1281,7 @@ starlette = ["gino-starlette"] tornado = ["gino-tornado"] [metadata] -content-hash = "603e4fe024e4deaf202b809d5d6bff9029c3985349e96b9a2191bf5518c31965" +content-hash = "fdc895a865b58bd3f44060782796a51b91aa81b11b4dd767fe8371ab13d4842a" python-versions = "^3.5" [metadata.files] @@ -1232,6 +1303,10 @@ aiohttp = [ {file = "aiohttp-3.6.2-py3-none-any.whl", hash = "sha256:460bd4237d2dbecc3b5ed57e122992f60188afe46e7319116da5eb8a9dfedba4"}, {file = "aiohttp-3.6.2.tar.gz", hash = "sha256:259ab809ff0727d0e834ac5e8a283dc5e3e0ecc30c4d80b3cd17a4139ce1f326"}, ] +aiomysql = [ + {file = "aiomysql-0.0.20-py3-none-any.whl", hash = "sha256:5fd798481f16625b424eec765c56d712ac78a51f3bd0175a3de94107aae43307"}, + {file = "aiomysql-0.0.20.tar.gz", hash = "sha256:d89ce25d44dadb43cf2d9e4603bd67b7a0ad12d5e67208de013629ba648df2ba"}, +] alabaster = [ {file = "alabaster-0.7.12-py2.py3-none-any.whl", hash = "sha256:446438bdcca0e05bd45ea2de1668c1d9b032e1a9154c2c259092d77031ddd359"}, {file = "alabaster-0.7.12.tar.gz", hash = "sha256:a661d72d58e6ea8a57f7a86e37d86716863ee5e92788398526d58b26a4e4dc02"}, @@ -1295,8 +1370,38 @@ blinker = [ {file = "blinker-1.4.tar.gz", hash = "sha256:471aee25f3992bd325afa3772f1063dbdbbca947a041b8b89466dc00d606f8b6"}, ] certifi = [ - {file = "certifi-2020.4.5.2-py2.py3-none-any.whl", hash = "sha256:9cd41137dc19af6a5e03b630eefe7d1f458d964d406342dd3edf625839b944cc"}, - {file = "certifi-2020.4.5.2.tar.gz", hash = "sha256:5ad7e9a056d25ffa5082862e36f119f7f7cec6457fa07ee2f8c339814b80c9b1"}, + {file = "certifi-2020.6.20-py2.py3-none-any.whl", hash = "sha256:8fc0819f1f30ba15bdb34cceffb9ef04d99f420f68eb75d901e9560b8749fc41"}, + {file = "certifi-2020.6.20.tar.gz", hash = "sha256:5930595817496dd21bb8dc35dad090f1c2cd0adfaf21204bf6732ca5d8ee34d3"}, +] +cffi = [ + {file = "cffi-1.14.0-cp27-cp27m-macosx_10_9_x86_64.whl", hash = "sha256:1cae98a7054b5c9391eb3249b86e0e99ab1e02bb0cc0575da191aedadbdf4384"}, + {file = "cffi-1.14.0-cp27-cp27m-manylinux1_i686.whl", hash = "sha256:cf16e3cf6c0a5fdd9bc10c21687e19d29ad1fe863372b5543deaec1039581a30"}, + {file = "cffi-1.14.0-cp27-cp27m-manylinux1_x86_64.whl", hash = "sha256:f2b0fa0c01d8a0c7483afd9f31d7ecf2d71760ca24499c8697aeb5ca37dc090c"}, + {file = "cffi-1.14.0-cp27-cp27m-win32.whl", hash = "sha256:99f748a7e71ff382613b4e1acc0ac83bf7ad167fb3802e35e90d9763daba4d78"}, + {file = "cffi-1.14.0-cp27-cp27m-win_amd64.whl", hash = "sha256:c420917b188a5582a56d8b93bdd8e0f6eca08c84ff623a4c16e809152cd35793"}, + {file = "cffi-1.14.0-cp27-cp27mu-manylinux1_i686.whl", hash = "sha256:399aed636c7d3749bbed55bc907c3288cb43c65c4389964ad5ff849b6370603e"}, + {file = "cffi-1.14.0-cp27-cp27mu-manylinux1_x86_64.whl", hash = "sha256:cab50b8c2250b46fe738c77dbd25ce017d5e6fb35d3407606e7a4180656a5a6a"}, + {file = "cffi-1.14.0-cp35-cp35m-macosx_10_9_x86_64.whl", hash = "sha256:001bf3242a1bb04d985d63e138230802c6c8d4db3668fb545fb5005ddf5bb5ff"}, + {file = 
"cffi-1.14.0-cp35-cp35m-manylinux1_i686.whl", hash = "sha256:e56c744aa6ff427a607763346e4170629caf7e48ead6921745986db3692f987f"}, + {file = "cffi-1.14.0-cp35-cp35m-manylinux1_x86_64.whl", hash = "sha256:b8c78301cefcf5fd914aad35d3c04c2b21ce8629b5e4f4e45ae6812e461910fa"}, + {file = "cffi-1.14.0-cp35-cp35m-win32.whl", hash = "sha256:8c0ffc886aea5df6a1762d0019e9cb05f825d0eec1f520c51be9d198701daee5"}, + {file = "cffi-1.14.0-cp35-cp35m-win_amd64.whl", hash = "sha256:8a6c688fefb4e1cd56feb6c511984a6c4f7ec7d2a1ff31a10254f3c817054ae4"}, + {file = "cffi-1.14.0-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:95cd16d3dee553f882540c1ffe331d085c9e629499ceadfbda4d4fde635f4b7d"}, + {file = "cffi-1.14.0-cp36-cp36m-manylinux1_i686.whl", hash = "sha256:66e41db66b47d0d8672d8ed2708ba91b2f2524ece3dee48b5dfb36be8c2f21dc"}, + {file = "cffi-1.14.0-cp36-cp36m-manylinux1_x86_64.whl", hash = "sha256:028a579fc9aed3af38f4892bdcc7390508adabc30c6af4a6e4f611b0c680e6ac"}, + {file = "cffi-1.14.0-cp36-cp36m-win32.whl", hash = "sha256:cef128cb4d5e0b3493f058f10ce32365972c554572ff821e175dbc6f8ff6924f"}, + {file = "cffi-1.14.0-cp36-cp36m-win_amd64.whl", hash = "sha256:337d448e5a725bba2d8293c48d9353fc68d0e9e4088d62a9571def317797522b"}, + {file = "cffi-1.14.0-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:e577934fc5f8779c554639376beeaa5657d54349096ef24abe8c74c5d9c117c3"}, + {file = "cffi-1.14.0-cp37-cp37m-manylinux1_i686.whl", hash = "sha256:62ae9af2d069ea2698bf536dcfe1e4eed9090211dbaafeeedf5cb6c41b352f66"}, + {file = "cffi-1.14.0-cp37-cp37m-manylinux1_x86_64.whl", hash = "sha256:14491a910663bf9f13ddf2bc8f60562d6bc5315c1f09c704937ef17293fb85b0"}, + {file = "cffi-1.14.0-cp37-cp37m-win32.whl", hash = "sha256:c43866529f2f06fe0edc6246eb4faa34f03fe88b64a0a9a942561c8e22f4b71f"}, + {file = "cffi-1.14.0-cp37-cp37m-win_amd64.whl", hash = "sha256:2089ed025da3919d2e75a4d963d008330c96751127dd6f73c8dc0c65041b4c26"}, + {file = "cffi-1.14.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:3b911c2dbd4f423b4c4fcca138cadde747abdb20d196c4a48708b8a2d32b16dd"}, + {file = "cffi-1.14.0-cp38-cp38-manylinux1_i686.whl", hash = "sha256:7e63cbcf2429a8dbfe48dcc2322d5f2220b77b2e17b7ba023d6166d84655da55"}, + {file = "cffi-1.14.0-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:3d311bcc4a41408cf5854f06ef2c5cab88f9fded37a3b95936c9879c1640d4c2"}, + {file = "cffi-1.14.0-cp38-cp38-win32.whl", hash = "sha256:675686925a9fb403edba0114db74e741d8181683dcf216be697d208857e04ca8"}, + {file = "cffi-1.14.0-cp38-cp38-win_amd64.whl", hash = "sha256:00789914be39dffba161cfc5be31b55775de5ba2235fe49aa28c148236c4e06b"}, + {file = "cffi-1.14.0.tar.gz", hash = "sha256:2d384f4a127a15ba701207f7639d94106693b6cd64173d6c8988e2c25f3ac2b6"}, ] chardet = [ {file = "chardet-3.0.4-py2.py3-none-any.whl", hash = "sha256:fc323ffcaeaed0e0a02bf4d117757b98aed530d9ed4531e3e15460124c106691"}, @@ -1346,6 +1451,27 @@ coverage = [ {file = "coverage-5.1-cp39-cp39-win_amd64.whl", hash = "sha256:bb28a7245de68bf29f6fb199545d072d1036a1917dca17a1e75bbb919e14ee8e"}, {file = "coverage-5.1.tar.gz", hash = "sha256:f90bfc4ad18450c80b024036eaf91e4a246ae287701aaa88eaebebf150868052"}, ] +cryptography = [ + {file = "cryptography-2.9.2-cp27-cp27m-macosx_10_9_x86_64.whl", hash = "sha256:daf54a4b07d67ad437ff239c8a4080cfd1cc7213df57d33c97de7b4738048d5e"}, + {file = "cryptography-2.9.2-cp27-cp27m-manylinux1_x86_64.whl", hash = "sha256:3b3eba865ea2754738616f87292b7f29448aec342a7c720956f8083d252bf28b"}, + {file = "cryptography-2.9.2-cp27-cp27m-manylinux2010_x86_64.whl", hash = 
"sha256:c447cf087cf2dbddc1add6987bbe2f767ed5317adb2d08af940db517dd704365"}, + {file = "cryptography-2.9.2-cp27-cp27m-win32.whl", hash = "sha256:f118a95c7480f5be0df8afeb9a11bd199aa20afab7a96bcf20409b411a3a85f0"}, + {file = "cryptography-2.9.2-cp27-cp27m-win_amd64.whl", hash = "sha256:c4fd17d92e9d55b84707f4fd09992081ba872d1a0c610c109c18e062e06a2e55"}, + {file = "cryptography-2.9.2-cp27-cp27mu-manylinux1_x86_64.whl", hash = "sha256:d0d5aeaedd29be304848f1c5059074a740fa9f6f26b84c5b63e8b29e73dfc270"}, + {file = "cryptography-2.9.2-cp27-cp27mu-manylinux2010_x86_64.whl", hash = "sha256:1e4014639d3d73fbc5ceff206049c5a9a849cefd106a49fa7aaaa25cc0ce35cf"}, + {file = "cryptography-2.9.2-cp35-abi3-macosx_10_9_x86_64.whl", hash = "sha256:96c080ae7118c10fcbe6229ab43eb8b090fccd31a09ef55f83f690d1ef619a1d"}, + {file = "cryptography-2.9.2-cp35-abi3-manylinux1_x86_64.whl", hash = "sha256:e993468c859d084d5579e2ebee101de8f5a27ce8e2159959b6673b418fd8c785"}, + {file = "cryptography-2.9.2-cp35-abi3-manylinux2010_x86_64.whl", hash = "sha256:88c881dd5a147e08d1bdcf2315c04972381d026cdb803325c03fe2b4a8ed858b"}, + {file = "cryptography-2.9.2-cp35-cp35m-win32.whl", hash = "sha256:651448cd2e3a6bc2bb76c3663785133c40d5e1a8c1a9c5429e4354201c6024ae"}, + {file = "cryptography-2.9.2-cp35-cp35m-win_amd64.whl", hash = "sha256:726086c17f94747cedbee6efa77e99ae170caebeb1116353c6cf0ab67ea6829b"}, + {file = "cryptography-2.9.2-cp36-cp36m-win32.whl", hash = "sha256:091d31c42f444c6f519485ed528d8b451d1a0c7bf30e8ca583a0cac44b8a0df6"}, + {file = "cryptography-2.9.2-cp36-cp36m-win_amd64.whl", hash = "sha256:bb1f0281887d89617b4c68e8db9a2c42b9efebf2702a3c5bf70599421a8623e3"}, + {file = "cryptography-2.9.2-cp37-cp37m-win32.whl", hash = "sha256:18452582a3c85b96014b45686af264563e3e5d99d226589f057ace56196ec78b"}, + {file = "cryptography-2.9.2-cp37-cp37m-win_amd64.whl", hash = "sha256:22e91636a51170df0ae4dcbd250d318fd28c9f491c4e50b625a49964b24fe46e"}, + {file = "cryptography-2.9.2-cp38-cp38-win32.whl", hash = "sha256:844a76bc04472e5135b909da6aed84360f522ff5dfa47f93e3dd2a0b84a89fa0"}, + {file = "cryptography-2.9.2-cp38-cp38-win_amd64.whl", hash = "sha256:1dfa985f62b137909496e7fc182dac687206d8d089dd03eaeb28ae16eec8e7d5"}, + {file = "cryptography-2.9.2.tar.gz", hash = "sha256:a0c30272fb4ddda5f5ffc1089d7405b7a71b0b0f51993cb4e5dbb4590b2fc229"}, +] docutils = [ {file = "docutils-0.16-py2.py3-none-any.whl", hash = "sha256:0c5b78adfbf7762415433f5515cd5c9e762339e23369dbe8000d84a4bf4ab3af"}, {file = "docutils-0.16.tar.gz", hash = "sha256:c2de3a60e9e7d07be26b7f2b00ca0309c207e06c100f9cc2a94931fc75a478fc"}, @@ -1383,8 +1509,8 @@ hpack = [ {file = "hpack-3.0.0.tar.gz", hash = "sha256:8eec9c1f4bfae3408a3f30500261f7e6a65912dc138526ea054f9ad98892e9d2"}, ] hstspreload = [ - {file = "hstspreload-2020.6.9-py3-none-any.whl", hash = "sha256:697987b7e849f315e5c4625cab662b390e991e7ef951884aa6013106c9c48000"}, - {file = "hstspreload-2020.6.9.tar.gz", hash = "sha256:1534715db2f5224debb605a82e3f79ee9f891031b748cdcf0441eb672d5f3aa2"}, + {file = "hstspreload-2020.6.16-py3-none-any.whl", hash = "sha256:c3a57dbc6abc898f55c791f37bb3ead338ca96b4d1c446cf58eea4e0c10bef4a"}, + {file = "hstspreload-2020.6.16.tar.gz", hash = "sha256:06a634aa0be9a51560be8dccfeceddeba6a2f0a9273433d778c93cdaf93c86d0"}, ] httptools = [ {file = "httptools-0.1.1-cp35-cp35m-macosx_10_13_x86_64.whl", hash = "sha256:a2719e1d7a84bb131c4f1e0cb79705034b48de6ae486eb5297a139d6a3296dce"}, @@ -1480,6 +1606,11 @@ markupsafe = [ {file = "MarkupSafe-1.1.1-cp37-cp37m-manylinux1_x86_64.whl", hash = 
"sha256:ba59edeaa2fc6114428f1637ffff42da1e311e29382d81b339c1817d37ec93c6"}, {file = "MarkupSafe-1.1.1-cp37-cp37m-win32.whl", hash = "sha256:b00c1de48212e4cc9603895652c5c410df699856a2853135b3967591e4beebc2"}, {file = "MarkupSafe-1.1.1-cp37-cp37m-win_amd64.whl", hash = "sha256:9bf40443012702a1d2070043cb6291650a0841ece432556f784f004937f0f32c"}, + {file = "MarkupSafe-1.1.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:6788b695d50a51edb699cb55e35487e430fa21f1ed838122d722e0ff0ac5ba15"}, + {file = "MarkupSafe-1.1.1-cp38-cp38-manylinux1_i686.whl", hash = "sha256:cdb132fc825c38e1aeec2c8aa9338310d29d337bebbd7baa06889d09a60a1fa2"}, + {file = "MarkupSafe-1.1.1-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:13d3144e1e340870b25e7b10b98d779608c02016d5184cfb9927a9f10c689f42"}, + {file = "MarkupSafe-1.1.1-cp38-cp38-win32.whl", hash = "sha256:596510de112c685489095da617b5bcbbac7dd6384aeebeda4df6025d0256a81b"}, + {file = "MarkupSafe-1.1.1-cp38-cp38-win_amd64.whl", hash = "sha256:e8313f01ba26fbbe36c7be1966a7b7424942f670f38e666995b88d012765b9be"}, {file = "MarkupSafe-1.1.1.tar.gz", hash = "sha256:29872e92839765e546828bb7754a68c418d927cd064fd4708fab9fe9c8bb116b"}, ] more-itertools = [ @@ -1525,6 +1656,12 @@ mypy-extensions = [ {file = "mypy_extensions-0.4.3-py2.py3-none-any.whl", hash = "sha256:090fedd75945a69ae91ce1303b5824f428daf5a028d2f6ab8a299250a846f15d"}, {file = "mypy_extensions-0.4.3.tar.gz", hash = "sha256:2d82818f5bb3e369420cb3c4060a7970edba416647068eb4c5343488a6c604a8"}, ] +mysqlclient = [ + {file = "mysqlclient-1.4.6-cp36-cp36m-win_amd64.whl", hash = "sha256:4c82187dd6ab3607150fbb1fa5ef4643118f3da122b8ba31c3149ddd9cf0cb39"}, + {file = "mysqlclient-1.4.6-cp37-cp37m-win_amd64.whl", hash = "sha256:9e6080a7aee4cc6a06b58b59239f20f1d259c1d2fddf68ddeed242d2311c7087"}, + {file = "mysqlclient-1.4.6-cp38-cp38-win_amd64.whl", hash = "sha256:f646f8d17d02be0872291f258cce3813497bc7888cd4712a577fd1e719b2f213"}, + {file = "mysqlclient-1.4.6.tar.gz", hash = "sha256:f3fdaa9a38752a3b214a6fe79d7cae3653731a53e577821f9187e67cbecb2e16"}, +] packaging = [ {file = "packaging-20.4-py2.py3-none-any.whl", hash = "sha256:998416ba6962ae7fbd6596850b80e17859a5753ba17c32284f67bfff33784181"}, {file = "packaging-20.4.tar.gz", hash = "sha256:4357f74f47b9c12db93624a82154e9b120fa8293699949152b22065d556079f8"}, @@ -1588,13 +1725,21 @@ psycopg2-binary = [ {file = "psycopg2_binary-2.8.5-cp38-cp38-win_amd64.whl", hash = "sha256:fa466306fcf6b39b8a61d003123d442b23707d635a5cb05ac4e1b62cc79105cd"}, ] py = [ - {file = "py-1.8.1-py2.py3-none-any.whl", hash = "sha256:c20fdd83a5dbc0af9efd622bee9a5564e278f6380fffcacc43ba6f43db2813b0"}, - {file = "py-1.8.1.tar.gz", hash = "sha256:5e27081401262157467ad6e7f851b7aa402c5852dbcb3dae06768434de5752aa"}, + {file = "py-1.8.2-py2.py3-none-any.whl", hash = "sha256:a673fa23d7000440cc885c17dbd34fafcb7d7a6e230b29f6766400de36a33c44"}, + {file = "py-1.8.2.tar.gz", hash = "sha256:f3b3a4c36512a4c4f024041ab51866f11761cc169670204b235f6b20523d4e6b"}, +] +pycparser = [ + {file = "pycparser-2.20-py2.py3-none-any.whl", hash = "sha256:7582ad22678f0fcd81102833f60ef8d0e57288b6b5fb00323d101be910e35705"}, + {file = "pycparser-2.20.tar.gz", hash = "sha256:2d475327684562c3a96cc71adf7dc8c4f0565175cf86b6d7a404ff4c771f15f0"}, ] pygments = [ {file = "Pygments-2.6.1-py3-none-any.whl", hash = "sha256:ff7a40b4860b727ab48fad6360eb351cc1b33cbf9b15a0f689ca5353e9463324"}, {file = "Pygments-2.6.1.tar.gz", hash = "sha256:647344a061c249a3b74e230c739f434d7ea4d8b1d5f3721bc0f3558049b38f44"}, ] +pymysql = [ + {file = 
"PyMySQL-0.9.2-py2.py3-none-any.whl", hash = "sha256:95f057328357e0e13a30e67857a8c694878b0175797a9a203ee7adbfb9b1ec5f"}, + {file = "PyMySQL-0.9.2.tar.gz", hash = "sha256:9ec760cbb251c158c19d6c88c17ca00a8632bac713890e465b2be01fdc30713f"}, +] pyparsing = [ {file = "pyparsing-2.4.7-py2.py3-none-any.whl", hash = "sha256:ef9d7589ef3c200abe66653d3f1ab1033c3c419ae9b9bdb1240a85b024efc88b"}, {file = "pyparsing-2.4.7.tar.gz", hash = "sha256:c203ec8783bf771a155b207279b9bccb8dea02d8f0c9e5f8ead507bc3246ecc1"}, @@ -1660,8 +1805,8 @@ regex = [ {file = "regex-2020.6.8.tar.gz", hash = "sha256:e9b64e609d37438f7d6e68c2546d2cb8062f3adb27e6336bc129b51be20773ac"}, ] requests = [ - {file = "requests-2.23.0-py2.py3-none-any.whl", hash = "sha256:43999036bfa82904b6af1d99e4882b560e5e2c68e5c4b0aa03b655f3d7d73fee"}, - {file = "requests-2.23.0.tar.gz", hash = "sha256:b3f43d496c6daba4493e7c431722aeb7dbc6288f52a6e04e7b6023b0247817e6"}, + {file = "requests-2.24.0-py2.py3-none-any.whl", hash = "sha256:fe75cc94a9443b9246fc7049224f75604b113c36acb93f87b80ed42c44cbb898"}, + {file = "requests-2.24.0.tar.gz", hash = "sha256:b3559a131db72c33ee969480840fff4bb6dd111de7dd27c8ee1f820f4f00231b"}, ] rfc3986 = [ {file = "rfc3986-1.4.0-py2.py3-none-any.whl", hash = "sha256:af9147e9aceda37c91a05f4deb128d4b4b49d6b199775fd2d2927768abdc8f50"}, diff --git a/pyproject.toml b/pyproject.toml index 90160fe8..c5ba55ac 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,6 +26,7 @@ classifiers = [ python = "^3.5" asyncpg = ">=0.18,<1.0" SQLAlchemy = ">=1.3,<1.4" +mysqlclient = "^1.4" # compatibility contextvars = { version = "^2.4", python = "<3.7" } @@ -37,6 +38,7 @@ gino-aiohttp = { version = "^0.1.0", optional = true, python = "^3.5.3" } gino-tornado = { version = "^0.1.0", optional = true, python = "^3.5.2" } gino-sanic = { version = "^0.1.0", optional = true, python = "^3.6" } gino-quart = { version = "^0.1.0", optional = true, python = "^3.7" } +aiomysql = "^0.0.20" [tool.poetry.extras] starlette = ["gino-starlette"] @@ -68,6 +70,8 @@ sphinx-intl = {extras = ["transifex"], version = "^2.0.1"} [tool.poetry.plugins."sqlalchemy.dialects"] "postgresql.asyncpg" = "gino.dialects.asyncpg:AsyncpgDialect" "asyncpg" = "gino.dialects.asyncpg:AsyncpgDialect" +"mysql.aiomysql" = "gino.dialects.aiomysql:AiomysqlDialect" +"aiomysql" = "gino.dialects.aiomysql:AiomysqlDialect" [build-system] requires = ["poetry>=1.0"] diff --git a/pytest.ini b/pytest.ini index 5ee64771..18057784 100644 --- a/pytest.ini +++ b/pytest.ini @@ -1,2 +1,2 @@ [pytest] -testpaths = tests +testpaths = tests mysql_tests diff --git a/src/gino/__init__.py b/src/gino/__init__.py index e4f35ab1..54a48c33 100644 --- a/src/gino/__init__.py +++ b/src/gino/__init__.py @@ -19,6 +19,8 @@ def create_engine(*args, **kwargs): * **Pre-bake** immediately when connected to the database (default). * No **pre-bake** but create prepared statements lazily when needed for the first time. + + Note: ``prebake`` has no effect in aiomysql """ from sqlalchemy import create_engine diff --git a/src/gino/crud.py b/src/gino/crud.py index 603d3abf..5a40bd84 100644 --- a/src/gino/crud.py +++ b/src/gino/crud.py @@ -7,6 +7,7 @@ from . 
import json_support from .declarative import Model, InvertDict +from .engine import GinoConnection from .exceptions import NoSuchRowError from .loader import AliasLoader, ModelLoader @@ -139,15 +140,20 @@ async def apply(self, bind=None, timeout=DEFAULT): updates[sa.cast(prop.name, sa.Unicode)] = value for prop_name, updates in json_updates.items(): prop = getattr(cls, prop_name) - from .dialects.asyncpg import JSONB + from .dialects.asyncpg import JSONB as psql_JSONB + from .dialects.aiomysql import JSON as mysql_JSON - if isinstance(prop.type, JSONB): + if isinstance(prop.type, psql_JSONB): if self._literal: values[prop_name] = prop.concat(updates) else: values[prop_name] = prop.concat( sa.func.jsonb_build_object(*itertools.chain(*updates.items())) ) + elif isinstance(prop.type, mysql_JSON): + values[prop_name] = sa.func.json_merge_patch( + prop, sa.func.json_object(*itertools.chain(*updates.items())) + ) else: raise TypeError( "{} is not supported to update json " @@ -162,16 +168,11 @@ async def apply(self, bind=None, timeout=DEFAULT): type(self._instance) .update.where(self._locator,) .values(**self._instance._get_sa_values(values),) - .returning(*[getattr(cls, key) for key in values],) .execution_options(**opts) ) - if bind is None: - bind = cls.__metadata__.bind - row = await bind.first(clause) - if not row: - raise NoSuchRowError() - for k, v in row.items(): - self._instance.__values__[self._instance._column_name_map.invert_get(k)] = v + await _query_and_update( + bind, self._instance, clause, [getattr(cls, key) for key in values], opts + ) for prop in self._props: prop.reload(self._instance) return self @@ -465,18 +466,12 @@ async def _create(self, bind=None, timeout=DEFAULT): opts = dict(return_model=False, model=cls) if timeout is not DEFAULT: opts["timeout"] = timeout - # noinspection PyArgumentList q = ( cls.__table__.insert() .values(**self._get_sa_values(self.__values__)) - .returning(*cls) .execution_options(**opts) ) - if bind is None: - bind = cls.__metadata__.bind - row = await bind.first(q) - for k, v in row.items(): - self.__values__[self._column_name_map.invert_get(k)] = v + await _query_and_update(bind, self, q, list(iter(cls)), opts) self.__profile__ = None return self @@ -784,3 +779,91 @@ def __iter__(self): def __call__(self, *args, **kwargs): return self._model(*args, **kwargs) + + +async def _query_and_update(bind, item, query, cols, execution_opts): + cls = type(item) + if bind is None: + bind = cls.__metadata__.bind + # noinspection PyProtectedMember + if bind._dialect.support_returning: + # noinspection PyArgumentList + query = query.returning(*cols) + + async def _execute_and_fetch(conn, query): + context, row = await conn._first_with_context(query) + # For DBMS like MySQL that doesn't support returning inserted or modified + # rows, a workaround is applied to infer necessary data to query from the + # database. This is not able to cover all cases, especially for those + # statements that the end results are not exactly the same as in the queries. + # One example is the DATETIME type in MySQL. By default, inserted date are + # rounded to seconds. This is not visible to the engine. 
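As a rough, hedged illustration of what this fallback amounts to on the wire (a sketch, not the exact code path below), an insert on a backend without RETURNING turns into an INSERT followed by a re-SELECT keyed on the driver-reported last insert id. The `gino_users` table, its columns, and the pool in this sketch are assumptions made up only for the example:

```python
import aiomysql


async def create_user(pool, nickname):
    # pool is assumed to come from aiomysql.create_pool(...); the table and
    # columns are hypothetical and exist only for this illustration.
    async with pool.acquire() as conn:
        cur = await conn.cursor()
        await cur.execute(
            "INSERT INTO gino_users (nickname) VALUES (%s)", (nickname,)
        )
        # MySQL has no RETURNING clause, so the generated primary key comes
        # from the protocol-level "last insert id" reported by the driver...
        last_id = cur.lastrowid
        # ...and any server-generated values (column defaults, DATETIMEs
        # rounded to whole seconds, etc.) have to be read back with a second
        # query, which is why the row seen here may differ slightly from the
        # values that were originally inserted.
        await cur.execute("SELECT * FROM gino_users WHERE id = %s", (last_id,))
        row = await cur.fetchone()
        await conn.commit()
        return row
```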
+ if not bind._dialect.support_returning: + if context.isinsert: + table = context.compiled.statement.table + key_getter = context.compiled._key_getters_for_crud_column[2] + compiled_params = context.compiled_parameters[0] + last_row_id = context.get_lastrowid() + if last_row_id is not None or table.primary_key: + lookup_conds = [ + c == last_row_id + if c is table._autoincrement_column + else c + == _cast_json(c, compiled_params.get(key_getter(c), None)) + for c in table.primary_key + ] + else: + lookup_conds = [ + c == _cast_json(c, compiled_params.get(key_getter(c), None)) + for c in table.columns + ] + query = ( + sa.select(table.columns) + .where(sa.and_(*lookup_conds)) + .execution_options(**execution_opts) + ) + row = await conn.first(query) + elif context.isupdate: + table = context.compiled.statement.table + if len(table.primary_key) > 0: + lookup_conds = [ + c + == _cast_json( + c, item.__values__[item._column_name_map.invert_get(c.name)] + ) + for c in table.primary_key + ] + else: + lookup_conds = [ + c + == _cast_json( + c, item.__values__[item._column_name_map.invert_get(c.name)] + ) + for c in table.columns + ] + query = ( + sa.select(table.columns) + .where(sa.and_(*lookup_conds)) + .execution_options(**execution_opts) + ) + row = await conn.first(query) + return row + + if isinstance(bind, GinoConnection): + row = await _execute_and_fetch(bind, query) + else: + async with bind.acquire(reuse=True) as conn: + row = await _execute_and_fetch(conn, query) + if not row: + raise NoSuchRowError() + for k, v in row.items(): + item.__values__[item._column_name_map.invert_get(k)] = v + + +def _cast_json(column, value): + # FIXME: for MySQL, json string in WHERE clause needs to be cast to JSON type + if isinstance(column.type, sa.JSON) or isinstance( + getattr(column.type, "impl", None), sa.JSON + ): + return sa.cast(value, sa.JSON) + return value diff --git a/src/gino/declarative.py b/src/gino/declarative.py index dc6c1665..e5eb0767 100644 --- a/src/gino/declarative.py +++ b/src/gino/declarative.py @@ -365,8 +365,9 @@ def _init_table(cls, sub_cls): json_col = getattr( sub_cls.__dict__.get(v.prop_name), "column", None ) - if not isinstance(json_col, sa.Column) or not isinstance( - json_col.type, sa.JSON + if not ( + isinstance(json_col, sa.Column) + and isinstance(json_col.type, sa.JSON) ): raise AttributeError( '{} "{}" requires a JSON[B] column "{}" ' diff --git a/src/gino/dialects/aiomysql.py b/src/gino/dialects/aiomysql.py new file mode 100644 index 00000000..9351d235 --- /dev/null +++ b/src/gino/dialects/aiomysql.py @@ -0,0 +1,495 @@ +import asyncio +import inspect +import itertools +import re +import time +import warnings + +import aiomysql +from sqlalchemy import util, exc +from sqlalchemy.dialects.mysql import JSON, ENUM +from sqlalchemy.dialects.mysql.base import ( + MySQLCompiler, + MySQLDialect, + MySQLExecutionContext, +) +from sqlalchemy.sql import sqltypes + +from . import base + +try: + import click +except ImportError: + click = None +JSON_COLTYPE = 245 + +#: Regular expression for :meth:`Cursor.executemany`. +#: executemany only supports simple bulk insert. +#: You can use it to load large dataset. +_RE_INSERT_VALUES = re.compile( + r"\s*((?:INSERT|REPLACE)\s.+\sVALUES?\s+)" + + r"(\(\s*(?:%s|%\(.+\)s)\s*(?:,\s*(?:%s|%\(.+\)s)\s*)*\))" + + r"(\s*(?:ON DUPLICATE.*)?);?\s*\Z", + re.IGNORECASE | re.DOTALL, +) + +#: Max statement size which :meth:`executemany` generates. +#: +#: Max size of allowed statement is max_allowed_packet - +# packet_header_size. 
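+#: As a purely illustrative example (the table and values are made up),
+#: executemany("INSERT INTO t (a, b) VALUES (%s, %s)", [(1, 'x'), (2, 'y')])
+#: is collapsed into the single statement
+#: "INSERT INTO t (a, b) VALUES (1,'x'),(2,'y')", and a new statement is
+#: started whenever appending the next escaped value tuple would push the
+#: packet past this limit.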
+#: Default value of max_allowed_packet is 1048576. +_MAX_STMT_LENGTH = 1024000 + + +class AiomysqlDBAPI(base.BaseDBAPI): + paramstyle = "format" + + +# noinspection PyAbstractClass +class AiomysqlExecutionContext(base.ExecutionContextOverride, MySQLExecutionContext): + def get_lastrowid(self): + lastrowid = self.cursor.last_row_id + return None if lastrowid == 0 else lastrowid + + def get_affected_rows(self): + return self.cursor.affected_rows + + +class AiomysqlIterator(base.Cursor): + def __init__(self, context, cursor): + self._context = context + self._cursor = cursor + self._queried = False + + def __await__(self): + async def return_self(): + return self + + return return_self().__await__() + + def __aiter__(self): + return self + + async def _init(self): + if not self._queried: + query = self._context.statement + args = self._context.parameters[0] + await self._cursor.execute(query, args) + self._context.cursor._cursor_description = self._cursor.description + self._queried = True + + async def __anext__(self): + await self._init() + row = await asyncio.wait_for(self._cursor.fetchone(), self._context.timeout) + if row is None: + raise StopAsyncIteration + return self._context.process_rows([row])[0] + + async def many(self, n, *, timeout=base.DEFAULT): + await self._init() + if timeout is base.DEFAULT: + timeout = self._context.timeout + rows = await asyncio.wait_for(self._cursor.fetchmany(n), timeout) + if not rows: + return [] + return self._context.process_rows(rows) + + async def next(self, *, timeout=base.DEFAULT): + try: + return await self.__anext__() + except StopAsyncIteration: + return None + + async def forward(self, n, *, timeout=base.DEFAULT): + await self._init() + if timeout is base.DEFAULT: + timeout = self._context.timeout + await asyncio.wait_for(self._cursor.scroll(n, mode="relative"), timeout) + + +class DBAPICursor(base.DBAPICursor): + def __init__(self, dbapi_conn): + self._conn = dbapi_conn + self._cursor_description = None + self._status = None + self.last_row_id = None + self.affected_rows = 0 + + async def prepare(self, context, clause=None): + raise Exception("aiomysql doesn't support prepare") + + async def async_execute(self, query, timeout, args, limit=0, many=False): + if timeout is None: + conn = await self._conn.acquire(timeout=timeout) + else: + before = time.monotonic() + conn = await self._conn.acquire(timeout=timeout) + after = time.monotonic() + timeout -= after - before + + if not many: + return await self._async_execute(conn, query, timeout, args) + + return await asyncio.wait_for( + self._async_executemany(conn, query, args), timeout=timeout + ) + + async def execute_baked(self, baked_query, timeout, args, one): + # TODO: use prepare when it's supported + return await self.async_execute(baked_query.sql, timeout, args) + + async def _async_execute(self, conn, query, timeout, args): + if args is not None: + query = query % _escape_args(args, conn) + await asyncio.wait_for(conn.query(query), timeout=timeout) + # noinspection PyProtectedMember + result = conn._result + self._cursor_description = result.description + self._status = result.affected_rows + self.last_row_id = result.insert_id + self.affected_rows = result.affected_rows + return result.rows + + async def _async_executemany(self, conn, query, args): + m = _RE_INSERT_VALUES.match(query) + if m: + q_prefix = m.group(1) + q_values = m.group(2).rstrip() + q_postfix = m.group(3) or "" + return await self._do_execute_many( + conn, q_prefix, q_values, q_postfix, args + ) + else: + rows = 0 + 
for arg in args: + await self.execute(query, arg) + rows += self.affected_rows + self.affected_rows = rows + return None + + async def _do_execute_many(self, conn, prefix, values, postfix, args): + escape = _escape_args + if isinstance(prefix, str): + prefix = prefix.encode(conn.encoding) + if isinstance(postfix, str): + postfix = postfix.encode(conn.encoding) + stmt = bytearray(prefix) + args = iter(args) + v = values % escape(next(args), conn) + if isinstance(v, str): + v = v.encode(conn.encoding, "surrogateescape") + stmt += v + rows = 0 + for arg in args: + v = values % escape(arg, conn) + if isinstance(v, str): + v = v.encode(conn.encoding, "surrogateescape") + if len(stmt) + len(v) + len(postfix) + 1 > _MAX_STMT_LENGTH: + await self._async_execute(conn, stmt + postfix, None, None) + rows += self.affected_rows + stmt = bytearray(prefix) + else: + stmt += b"," + stmt += v + await self._async_execute(conn, stmt + postfix, None, None) + self.affected_rows += rows + + @property + def description(self): + return self._cursor_description + + def get_statusmsg(self): + return self._status + + def iterate(self, context): + # use SSCursor to get server side cursor + return AiomysqlIterator(context, aiomysql.SSCursor(self._conn.raw_connection)) + + +class Pool(base.Pool): + def __init__(self, url, loop, init=None, bakery=None, prebake=True, **kwargs): + self._url = url + self._loop = loop + self._kwargs = kwargs + self._pool = None + self._conn_init = init + self._bakery = bakery + self._prebake = prebake + + async def _init(self): + args = self._kwargs.copy() + args.update( + loop=self._loop, + host=self._url.host, + port=self._url.port, + user=self._url.username, + db=self._url.database, + password=self._url.password, + ) + # aiomysql sets autocommit as False by default, which opposes the MySQL + # default, therefore it's set to None to respect the MySQL configuration + args.setdefault("autocommit", None) + self._pool = await aiomysql.create_pool(**args) + return self + + def __await__(self): + return self._init().__await__() + + @property + def raw_pool(self): + return self._pool + + async def acquire(self, *, timeout=None): + if timeout is None: + conn = await self._pool.acquire() + else: + conn = await asyncio.wait_for(self._pool.acquire(), timeout=timeout) + if self._conn_init is not None: + try: + await self._conn_init(conn) + except: + await self.release(conn) + raise + return conn + + async def release(self, conn): + await self._pool.release(conn) + + async def close(self): + self._pool.close() + await self._pool.wait_closed() + + def repr(self, color): + if color and not click: + warnings.warn("Install click to get colorful repr.", ImportWarning) + + if color and click: + # noinspection PyProtectedMember + return "<{classname} max={max} min={min} cur={cur} use={use}>".format( + classname=click.style( + self._pool.__class__.__module__ + + "." + + self._pool.__class__.__name__, + fg="green", + ), + max=click.style(repr(self._pool.maxsize), fg="cyan"), + min=click.style(repr(self._pool._minsize), fg="cyan"), + cur=click.style(repr(self._pool.size), fg="cyan"), + use=click.style(repr(len(self._pool._used)), fg="cyan"), + ) + else: + # noinspection PyProtectedMember + return "<{classname} max={max} min={min} cur={cur} use={use}>".format( + classname=self._pool.__class__.__module__ + + "." 
+ + self._pool.__class__.__name__, + max=self._pool.maxsize, + min=self._pool._minsize, + cur=self._pool.size, + use=len(self._pool._used), + ) + + +class Transaction(base.Transaction): + def __init__(self, conn, set_isolation=None): + self._conn = conn + self._set_isolation = set_isolation + + @property + def raw_transaction(self): + return self._conn + + async def begin(self): + await self._conn.begin() + if self._set_isolation is not None: + await self._set_isolation(self._conn) + + async def commit(self): + await self._conn.commit() + + async def rollback(self): + await self._conn.rollback() + + +# MySQL doesn't need to create ENUM types like PostgreSQL, do nothing here +class AsyncEnum(ENUM): + async def create_async(self, bind=None, checkfirst=True): + pass + + async def drop_async(self, bind=None, checkfirst=True): + pass + + async def _on_table_create_async(self, target, bind, checkfirst=False, **kw): + pass + + async def _on_table_drop_async(self, target, bind, checkfirst=False, **kw): + pass + + async def _on_metadata_create_async(self, target, bind, checkfirst=False, **kw): + pass + + async def _on_metadata_drop_async(self, target, bind, checkfirst=False, **kw): + pass + + +class GinoNullType(sqltypes.NullType): + def result_processor(self, dialect, coltype): + if coltype == JSON_COLTYPE: + return JSON().result_processor(dialect, coltype) + return super().result_processor(dialect, coltype) + + +# noinspection PyAbstractClass +class AiomysqlDialect(MySQLDialect, base.AsyncDialectMixin): + driver = "aiomysql" + supports_native_decimal = True + dbapi_class = AiomysqlDBAPI + statement_compiler = MySQLCompiler + execution_ctx_cls = AiomysqlExecutionContext + cursor_cls = DBAPICursor + init_kwargs = set( + itertools.chain( + ("bakery", "prebake"), + *[ + inspect.getfullargspec(f).args + for f in [aiomysql.create_pool, aiomysql.connect] + ] + ) + ) - { + "echo" + } # use SQLAlchemy's echo instead + colspecs = util.update_copy( + MySQLDialect.colspecs, + {ENUM: AsyncEnum, sqltypes.Enum: AsyncEnum, sqltypes.NullType: GinoNullType,}, + ) + postfetch_lastrowid = False + support_returning = False + support_prepare = False + + def __init__(self, *args, bakery=None, **kwargs): + self._pool_kwargs = {} + for k in self.init_kwargs: + if k in kwargs: + self._pool_kwargs[k] = kwargs.pop(k) + super().__init__(*args, **kwargs) + self._init_mixin(bakery) + + async def init_pool(self, url, loop, pool_class=None): + if pool_class is None: + pool_class = Pool + return await pool_class( + url, loop, bakery=self._bakery, init=self.on_connect(), **self._pool_kwargs + ) + + # noinspection PyMethodMayBeStatic + def transaction(self, raw_conn, args, kwargs): + _set_isolation = None + if "isolation" in kwargs: + + async def _set_isolation(conn): + await self.set_isolation_level(conn, kwargs["isolation"]) + + return Transaction(raw_conn, _set_isolation) + + def on_connect(self): + if self.isolation_level is not None: + + async def connect(conn): + await self.set_isolation_level(conn, self.isolation_level) + + return connect + else: + return None + + async def set_isolation_level(self, connection, level): + level = level.replace("_", " ") + await self._set_isolation_level(connection, level) + + async def _set_isolation_level(self, connection, level): + if level not in self._isolation_lookup: + raise exc.ArgumentError( + "Invalid value '%s' for isolation_level. 
" + "Valid isolation levels for %s are %s" + % (level, self.name, ", ".join(self._isolation_lookup)) + ) + cursor = await connection.cursor() + await cursor.execute("SET SESSION TRANSACTION ISOLATION LEVEL %s" % level) + await cursor.execute("COMMIT") + await cursor.close() + + async def get_isolation_level(self, connection): + if self.server_version_info is None: + self.server_version_info = await self._get_server_version_info(connection) + cursor = await connection.cursor() + if self._is_mysql and self.server_version_info >= (5, 7, 20): + await cursor.execute("SELECT @@transaction_isolation") + else: + await cursor.execute("SELECT @@tx_isolation") + row = await cursor.fetchone() + if row is None: + util.warn( + "Could not retrieve transaction isolation level for MySQL " + "connection." + ) + raise NotImplementedError() + val = row[0] + await cursor.close() + if isinstance(val, bytes): + val = val.decode() + return val.upper().replace("-", " ") + + async def _get_server_version_info(self, connection): + # get database server version info explicitly over the wire + # to avoid proxy servers like MaxScale getting in the + # way with their own values, see #4205 + cursor = await connection.cursor() + await cursor.execute("SELECT VERSION()") + val = (await cursor.fetchone())[0] + await cursor.close() + if isinstance(val, bytes): + val = val.decode() + + return self._parse_server_version(val) + + def _parse_server_version(self, val): + version = [] + r = re.compile(r"[.\-]") + for n in r.split(val): + try: + version.append(int(n)) + except ValueError: + mariadb = re.match(r"(.*)(MariaDB)(.*)", n) + if mariadb: + version.extend(g for g in mariadb.groups() if g) + else: + version.append(n) + return tuple(version) + + async def has_table(self, connection, table_name, schema=None): + full_name = ".".join( + self.identifier_preparer._quote_free_identifiers(schema, table_name) + ) + + st = "DESCRIBE %s" % full_name + try: + return await connection.first(st) is not None + except aiomysql.ProgrammingError as e: + if self._extract_error_code(e) == 1146: + return False + raise + + def _extract_error_code(self, exception): + if isinstance(exception.args[0], Exception): + exception = exception.args[0] + return exception.args[0] + + +def _escape_args(args, conn): + if isinstance(args, (tuple, list)): + return tuple(conn.escape(arg) for arg in args) + elif isinstance(args, dict): + return dict((key, conn.escape(val)) for (key, val) in args.items()) + else: + # If it's not a dictionary let's try escaping it anyways. + # Worst case it will throw a Value error + return conn.escape(args) diff --git a/src/gino/dialects/base.py b/src/gino/dialects/base.py index 0b0d630d..01d72628 100644 --- a/src/gino/dialects/base.py +++ b/src/gino/dialects/base.py @@ -93,9 +93,7 @@ async def _do_execute( self.clause, *multiparams, **params ).context if ctx.executemany: - raise ValueError( - "PreparedStatement does not support multiple " "parameters." 
- ) + raise ValueError("PreparedStatement does not support multiple parameters.") if ctx.statement != self.context.statement: raise AssertionError( "Prepared statement generated different SQL with parameters" @@ -160,10 +158,12 @@ def __init__(self, context): self._context = context async def _iterate(self): - prepared = await self._context.cursor.prepare(self._context) - return prepared.iterate( - *self._context.parameters[0], timeout=self._context.timeout - ) + if self._context.dialect.support_prepare: + prepared = await self._context.cursor.prepare(self._context) + return prepared.iterate( + *self._context.parameters[0], timeout=self._context.timeout + ) + return self._context.cursor.iterate(self._context) async def _get_cursor(self): return await (await self._iterate()) @@ -196,7 +196,9 @@ def __init__(self, context): def context(self): return self._context - async def execute(self, one=False, return_model=True, status=False): + async def execute( + self, one=False, return_model=True, status=False, return_context=False + ): context = self._context param_groups = [] @@ -213,25 +215,26 @@ async def execute(self, one=False, return_model=True, status=False): return await cursor.async_execute( context.statement, context.timeout, param_groups, many=True ) + args = param_groups[0] + if context.baked_query: + rows = await cursor.execute_baked( + context.baked_query, context.timeout, args, one + ) else: - args = param_groups[0] - if context.baked_query: - rows = await cursor.execute_baked( - context.baked_query, context.timeout, args, one - ) + rows = await cursor.async_execute( + context.statement, context.timeout, args, 1 if one else 0 + ) + item = context.process_rows(rows, return_model=return_model) + if one: + if item: + item = item[0] else: - rows = await cursor.async_execute( - context.statement, context.timeout, args, 1 if one else 0 - ) - item = context.process_rows(rows, return_model=return_model) - if one: - if item: - item = item[0] - else: - item = None - if status: - item = cursor.get_statusmsg(), item - return item + item = None + if status: + return cursor.get_statusmsg(), item + if return_context: + return context, item + return item def iterate(self): if self._context.executemany: @@ -291,6 +294,8 @@ def loader(self): return self._compiled_first_opt("loader", None) def process_rows(self, rows, return_model=True): + if not rows: + return [] # noinspection PyUnresolvedReferences rv = rows = super().get_result_proxy().process_rows(rows) loader = self.loader @@ -405,10 +410,20 @@ def _init_baked_query(cls, dialect, connection, dbapi_connection, bq, parameters self.baked_query = bq return self + def get_lastrowid(self): + raise NotImplementedError + + def get_affected_rows(self): + # Note: in MySQL result, affected rows means the number of rows that get + # updated, but not the number of matched rows. 
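+ # For example, an UPDATE that matches three rows but changes only two of + # them reports 2 here; MySQL counts matched rows instead only when the + # connection is created with the CLIENT_FOUND_ROWS client flag.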
+ raise NotImplementedError + class AsyncDialectMixin: cursor_cls = DBAPICursor dbapi_class = BaseDBAPI + support_returning = True + support_prepare = True _bakery = None def _init_mixin(self, bakery): @@ -437,7 +452,7 @@ def compile(self, elem, *multiparams, **params): else: return context.statement, context.parameters[0] - async def init_pool(self, url, loop): + async def init_pool(self, url, loop, pool_class=None): raise NotImplementedError def transaction(self, raw_conn, args, kwargs): diff --git a/src/gino/engine.py b/src/gino/engine.py index 28bf7f59..3813a7a3 100644 --- a/src/gino/engine.py +++ b/src/gino/engine.py @@ -348,6 +348,10 @@ async def first(self, clause, *multiparams, **params): result = self._execute(clause, multiparams, params) return await result.execute(one=True) + async def _first_with_context(self, clause, *multiparams, **params): + result = self._execute(clause, multiparams, params) + return await result.execute(one=True, return_context=True) + async def one_or_none(self, clause, *multiparams, **params): """ Runs the given query in database, returns at most one result. diff --git a/src/gino/json_support.py b/src/gino/json_support.py index 79faa505..272fdb53 100644 --- a/src/gino/json_support.py +++ b/src/gino/json_support.py @@ -1,6 +1,7 @@ from datetime import datetime import sqlalchemy as sa +from sqlalchemy.dialects import mysql from .exceptions import UnknownJSONPropertyError @@ -114,12 +115,18 @@ def __hash__(self): class StringProperty(JSONProperty): def make_expression(self, base_exp): - return base_exp.astext + try: + return base_exp.astext + except AttributeError: + return base_exp.cast(sa.String) class DateTimeProperty(JSONProperty): def make_expression(self, base_exp): - return base_exp.astext.cast(sa.DateTime) + try: + return base_exp.astext.cast(sa.DateTime) + except AttributeError: + return sa.func.json_unquote(base_exp).cast(mysql.DATETIME(fsp=6)) def decode(self, val): if val: @@ -134,7 +141,10 @@ def encode(self, val): class IntegerProperty(JSONProperty): def make_expression(self, base_exp): - return base_exp.astext.cast(sa.Integer) + try: + return base_exp.astext.cast(sa.Integer) + except AttributeError: + return base_exp.cast(sa.Integer) def decode(self, val): if val is not None: @@ -149,7 +159,10 @@ def encode(self, val): class BooleanProperty(JSONProperty): def make_expression(self, base_exp): - return base_exp.astext.cast(sa.Boolean) + try: + return base_exp.astext.cast(sa.Boolean) + except AttributeError: + return base_exp.cast(sa.Boolean) def decode(self, val): if val is not None: diff --git a/src/gino/strategies.py b/src/gino/strategies.py index eeb1cd85..13bbb83e 100644 --- a/src/gino/strategies.py +++ b/src/gino/strategies.py @@ -29,6 +29,9 @@ async def create(self, name_or_url, loop=None, **kwargs): if u.drivername in {"postgresql", "postgres"}: u = copy(u) u.drivername = "postgresql+asyncpg" + elif u.drivername in {"mysql"}: + u = copy(u) + u.drivername = "mysql+aiomysql" dialect_cls = u.get_dialect() @@ -61,6 +64,7 @@ async def create(self, name_or_url, loop=None, **kwargs): # all kwargs should be consumed if kwargs: + await pool.close() raise TypeError( "Invalid argument(s) %s sent to create_engine(), " "using configuration %s/%s. 
Please check that the " diff --git a/tests/test_engine.py b/tests/test_engine.py index b3439eec..b4a9a633 100644 --- a/tests/test_engine.py +++ b/tests/test_engine.py @@ -288,6 +288,7 @@ async def acquire_failed(*args, **kwargs): await blocker assert qsize(engine) == init_size + await engine.close() async def test_release(engine): diff --git a/tests/test_json.py b/tests/test_json.py index 3ac55ec9..1c646a89 100644 --- a/tests/test_json.py +++ b/tests/test_json.py @@ -330,7 +330,7 @@ class PropsTest(db.Model): await PropsTest.gino.create() try: - t = await PropsTest.create(bool=True, bool1=True,) + await PropsTest.create(bool=True, bool1=True,) profile = await bind.scalar("SELECT profile FROM props_test_291") assert isinstance(profile, dict) profile1 = await bind.scalar("SELECT profile1 FROM props_test_291")