Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 32 additions & 22 deletions api/models/account.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
import enum
import json
from datetime import datetime
from typing import Optional, cast
from typing import Optional

import sqlalchemy as sa
from flask_login import UserMixin
from sqlalchemy import DateTime, String, func, select
from sqlalchemy.orm import Mapped, mapped_column, reconstructor
from sqlalchemy.orm import Mapped, Session, mapped_column, reconstructor

from models.base import Base

Expand Down Expand Up @@ -118,10 +118,24 @@ def current_tenant(self):

@current_tenant.setter
def current_tenant(self, tenant: "Tenant"):
ta = db.session.scalar(select(TenantAccountJoin).filter_by(tenant_id=tenant.id, account_id=self.id).limit(1))
if ta:
self.role = TenantAccountRole(ta.role)
self._current_tenant = tenant
with Session(db.engine, expire_on_commit=False) as session:
tenant_join_query = select(TenantAccountJoin).where(
TenantAccountJoin.tenant_id == tenant.id, TenantAccountJoin.account_id == self.id
)
tenant_join = session.scalar(tenant_join_query)
tenant_query = select(Tenant).where(Tenant.id == tenant.id)
# TODO: A workaround to reload the tenant with `expire_on_commit=False`, allowing
# access to it after the session has been closed.
# This prevents `DetachedInstanceError` when accessing the tenant outside
# the session's lifecycle.
# (The `tenant` argument is typically loaded by `db.session` without the
# `expire_on_commit=False` flag, meaning its lifetime is tied to the web
# request's lifecycle.)
tenant_reloaded = session.scalars(tenant_query).one()

if tenant_join:
self.role = TenantAccountRole(tenant_join.role)
self._current_tenant = tenant_reloaded
return
self._current_tenant = None

Expand All @@ -130,23 +144,19 @@ def current_tenant_id(self) -> str | None:
return self._current_tenant.id if self._current_tenant else None

def set_tenant_id(self, tenant_id: str):
tenant_account_join = cast(
tuple[Tenant, TenantAccountJoin],
(
db.session.query(Tenant, TenantAccountJoin)
.where(Tenant.id == tenant_id)
.where(TenantAccountJoin.tenant_id == Tenant.id)
.where(TenantAccountJoin.account_id == self.id)
.one_or_none()
),
query = (
select(Tenant, TenantAccountJoin)
.where(Tenant.id == tenant_id)
.where(TenantAccountJoin.tenant_id == Tenant.id)
.where(TenantAccountJoin.account_id == self.id)
)

if not tenant_account_join:
return

tenant, join = tenant_account_join
self.role = TenantAccountRole(join.role)
self._current_tenant = tenant
with Session(db.engine, expire_on_commit=False) as session:
tenant_account_join = session.execute(query).first()
if not tenant_account_join:
return
tenant, join = tenant_account_join
self.role = TenantAccountRole(join.role)
self._current_tenant = tenant

@property
def current_role(self):
Expand Down