diff --git a/sqlalchemy_utils/view.py b/sqlalchemy_utils/view.py index 7fe4f043..e11d67c6 100644 --- a/sqlalchemy_utils/view.py +++ b/sqlalchemy_utils/view.py @@ -6,24 +6,27 @@ class CreateView(DDLElement): - def __init__(self, name, selectable, materialized=False): - self.name = name + def __init__(self, table, selectable, materialized=False): + self.table = table self.selectable = selectable self.materialized = materialized + # self.schema = schema @compiler.compiles(CreateView) def compile_create_materialized_view(element, compiler, **kw): return 'CREATE {}VIEW {} AS {}'.format( 'MATERIALIZED ' if element.materialized else '', - compiler.dialect.identifier_preparer.quote(element.name), + # compiler.dialect.identifier_preparer.quote(element.name), + compiler.dialect.identifier_preparer.format_table( + element.table, use_schema=True), compiler.sql_compiler.process(element.selectable, literal_binds=True), ) class DropView(DDLElement): - def __init__(self, name, materialized=False, cascade=True): - self.name = name + def __init__(self, table, materialized=False, cascade=True): + self.table = table self.materialized = materialized self.cascade = cascade @@ -32,7 +35,9 @@ def __init__(self, name, materialized=False, cascade=True): def compile_drop_materialized_view(element, compiler, **kw): return 'DROP {}VIEW IF EXISTS {} {}'.format( 'MATERIALIZED ' if element.materialized else '', - compiler.dialect.identifier_preparer.quote(element.name), + # compiler.dialect.identifier_preparer.quote(element.name), + compiler.dialect.identifier_preparer.format_table( + element.table, use_schema=True), 'CASCADE' if element.cascade else '' ) @@ -43,6 +48,7 @@ def create_table_from_selectable( indexes=None, metadata=None, aliases=None, + schema=None, **kwargs ): if indexes is None: @@ -60,7 +66,7 @@ def create_table_from_selectable( ) for c in get_columns(selectable) ] + indexes - table = sa.Table(name, metadata, *args, **kwargs) + table = sa.Table(name, metadata, *args, **kwargs, schema=schema) if not any([c.primary_key for c in get_columns(selectable)]): table.append_constraint( @@ -74,7 +80,8 @@ def create_materialized_view( selectable, metadata, indexes=None, - aliases=None + aliases=None, + schema=None, ): """ Create a view on a given metadata @@ -87,6 +94,7 @@ def create_materialized_view( :param aliases: An optional dictionary containing with keys as column names and values as column aliases. + :param schema: optinal the schema name for the view Same as for ``create_view`` except that a ``CREATE MATERIALIZED VIEW`` statement is emitted instead of a ``CREATE VIEW``. @@ -97,13 +105,14 @@ def create_materialized_view( selectable=selectable, indexes=indexes, metadata=None, - aliases=aliases + aliases=aliases, + schema=schema ) sa.event.listen( metadata, 'after_create', - CreateView(name, selectable, materialized=True) + CreateView(table, selectable, materialized=True) ) @sa.event.listens_for(metadata, 'after_create') @@ -114,7 +123,7 @@ def create_indexes(target, connection, **kw): sa.event.listen( metadata, 'before_drop', - DropView(name, materialized=True) + DropView(table, materialized=True) ) return table @@ -123,6 +132,8 @@ def create_view( name, selectable, metadata, + schema=None, + # indexes=None, Does non-materialized views allow index creation?? cascade_on_drop=True ): """ Create a view on a given metadata @@ -132,6 +143,8 @@ def create_view( :param metadata: An SQLAlchemy Metadata instance that stores the features of the database being described. + :param schema: optinal the schema name for the view + The process for creating a view is similar to the standard way that a table is constructed, except that a selectable is provided instead of @@ -147,10 +160,11 @@ def create_view( Column('name', String), Column('fullname', String), Column('premium_user', Boolean, default=False), + schema=None ) premium_members = select([users]).where(users.c.premium_user == True) - create_view('premium_users', premium_members, metadata) + create_view('premium_users', premium_members, metadata,) metadata.create_all(engine) # View is created at this point @@ -158,39 +172,63 @@ def create_view( table = create_table_from_selectable( name=name, selectable=selectable, + schema=schema, + # indexes=indexes,??? metadata=None ) - sa.event.listen(metadata, 'after_create', CreateView(name, selectable)) + sa.event.listen(metadata, 'after_create', CreateView(table, selectable)) @sa.event.listens_for(metadata, 'after_create') def create_indexes(target, connection, **kw): + # Does non-materialized views allow index creation?? for idx in table.indexes: idx.create(connection) sa.event.listen( metadata, 'before_drop', - DropView(name, cascade=cascade_on_drop) + DropView(table, cascade=cascade_on_drop) ) return table -def refresh_materialized_view(session, name, concurrently=False): +def refresh_materialized_view(session, view, concurrently=False): """ Refreshes an already existing materialized view :param session: An SQLAlchemy Session instance. - :param name: The name of the materialized view to refresh. + :param view: The view to refresh. :param concurrently: Optional flag that causes the ``CONCURRENTLY`` parameter to be specified when the materialized view is refreshed. + + + example (flask_sqlalchemy) ORM: + ArticleMV(db.Model): + __table__ = create_materialized_view( + name = 'article-mv', + selectable = db.select(...), + schema = 'main' + ) + @classmethod + def refresh_view(cls, concurrently=False): + refresh_materialized_view(db.session, cls, concurrently) + + User.refresh_view() + >SQL: REFRESH MATERIALIZED VIEW main.article-mv """ # Since session.execute() bypasses autoflush, we must manually flush in # order to include newly-created/modified objects in the refresh. + # session.bind.engine.dialect.identifier_preparer + # do no accept str as a param, it schould be the table + session.flush() session.execute( 'REFRESH MATERIALIZED VIEW {}{}'.format( 'CONCURRENTLY ' if concurrently else '', - session.bind.engine.dialect.identifier_preparer.quote(name) + # session.bind.engine.dialect.identifier_preparer.quote(name) + session.bind.engine.dialect.identifier_preparer.format_table( + view.__table__, use_schema=True) ) ) + session.commit() # needed to persist changes in the materialized view diff --git a/tests/test_views.py b/tests/test_views.py index 4bd8b0a5..c0279c52 100644 --- a/tests/test_views.py +++ b/tests/test_views.py @@ -94,7 +94,7 @@ def test_refresh_materialized_view( ) session.add(article) session.commit() - refresh_materialized_view(session, 'article-mv') + refresh_materialized_view(session, ArticleMV) materialized = session.query(ArticleMV).first() assert materialized.article_name == 'Some article' assert materialized.author_name == 'Some user' diff --git a/tests/test_views_schema.py b/tests/test_views_schema.py new file mode 100644 index 00000000..21bae3d3 --- /dev/null +++ b/tests/test_views_schema.py @@ -0,0 +1,202 @@ +import pytest +import sqlalchemy as sa + +from sqlalchemy_utils import ( + create_materialized_view, + create_view, + refresh_materialized_view +) + + +@pytest.fixture +def Article(Base, User): + class Article(Base): + __tablename__ = 'article' + id = sa.Column(sa.Integer, primary_key=True) + name = sa.Column(sa.String) + author_id = sa.Column(sa.Integer, sa.ForeignKey(User.id)) + author = sa.orm.relationship(User) + return Article + + +@pytest.fixture +def User(Base): + class User(Base): + __tablename__ = 'user' + id = sa.Column(sa.Integer, primary_key=True) + name = sa.Column(sa.String) + return User + + +@pytest.fixture +def ArticleMV(Base, Article, User, engine): + # def create_schema(engine): + if not engine.dialect.has_schema(engine, 'main'): + engine.execute(sa.schema.CreateSchema('main')) + + class ArticleMV(Base): + __table__ = create_materialized_view( + name='article-mv', + selectable=sa.select( + [ + Article.id, + Article.name, + User.id.label('author_id'), + User.name.label('author_name') + ], + from_obj=( + Article.__table__ + .join(User, Article.author_id == User.id) + ) + ), + aliases={'name': 'article_name'}, + metadata=Base.metadata, + schema='main', + indexes=[sa.Index('article-mv_id_idx', 'id')] + ) + # __table_args__ = {"schema": "main"} + return ArticleMV + + +@pytest.fixture +def ArticleView(Base, Article, User): + class ArticleView(Base): + __table__ = create_view( + name='article-view', + selectable=sa.select( + [ + Article.id, + Article.name, + User.id.label('author_id'), + User.name.label('author_name') + ], + from_obj=( + Article.__table__ + .join(User, Article.author_id == User.id) + ) + ), + schema='main', + metadata=Base.metadata + ) + # __table_args__ = {"schema": "main"} + + return ArticleView + + +@pytest.fixture +def init_models(ArticleMV, ArticleView): + pass + + +@pytest.mark.usefixtures('postgresql_dsn') +class TestMaterializedViews: + + # def create_schema(engine): + # if not engine.dialect.has_schema(engine, 'main'): + # engine.exeute(sa.schema.CreateSchema('main')) + + def test_refresh_materialized_view( + self, + session, + Article, + User, + ArticleMV + ): + article = Article( + name='Some article', + author=User(name='Some user') + ) + session.add(article) + session.commit() + refresh_materialized_view(session, ArticleMV) + materialized = session.query(ArticleMV).first() + assert materialized.article_name == 'Some article' + assert materialized.author_name == 'Some user' + + def test_querying_view( + self, + session, + Article, + User, + # ArticleMV + ArticleView + ): + article = Article( + name='Some article', + author=User(name='Some user') + ) + session.add(article) + session.commit() + row = session.query(ArticleView).first() + assert row.name == 'Some article' + assert row.author_name == 'Some user' + + def drop_view(self, engine, ArticleMV, ArticleView): + ArticleView.__table__.drop(engine) + ArticleMV.__table__.drop(engine) + if engine.dialect.has_schema(engine, 'main'): + engine.execute(sa.schema.DropSchema('main')) + + +class TrivialViewTestCases: + def life_cycle( + self, + engine, + metadata, + column, + cascade_on_drop + ): + __table__ = create_view( + name='trivial_view', + selectable=sa.select([column]), + metadata=metadata, + cascade_on_drop=cascade_on_drop + ) + __table__.create(engine) + __table__.drop(engine) + + +class SupportsCascade(TrivialViewTestCases): + def test_life_cycle_cascade( + self, + connection, + engine, + Base, + User + ): + self.life_cycle(engine, Base.metadata, User.id, cascade_on_drop=True) + + +class DoesntSupportCascade(SupportsCascade): + @pytest.mark.xfail + def test_life_cycle_cascade(self, *args, **kwargs): + super(DoesntSupportCascade, self).test_life_cycle_cascade( + *args, + **kwargs + ) + + +class SupportsNoCascade(TrivialViewTestCases): + def test_life_cycle_no_cascade( + self, + connection, + engine, + Base, + User + ): + self.life_cycle(engine, Base.metadata, User.id, cascade_on_drop=False) + + +@pytest.mark.usefixtures('postgresql_dsn') +class TestPostgresTrivialView(SupportsCascade, SupportsNoCascade): + pass + + +# @pytest.mark.usefixtures('mysql_dsn') +# class TestMySqlTrivialView(SupportsCascade, SupportsNoCascade): +# pass +# +# +# @pytest.mark.usefixtures('sqlite_none_database_dsn') +# class TestSqliteTrivialView(DoesntSupportCascade, SupportsNoCascade): +# pass