Skip to content

Commit 1098d10

Browse files
authored
Delete OIDC account when deleting user (#728)
1 parent a4d4793 commit 1098d10

File tree

1 file changed

+43
-18
lines changed

1 file changed

+43
-18
lines changed

gramps_webapi/auth/__init__.py

Lines changed: 43 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -141,11 +141,18 @@ def get_tree(guid: str) -> Optional[str]:
141141

142142

143143
def delete_user(name: str) -> None:
144-
"""Delete an existing user."""
144+
"""Delete an existing user and their associated OIDC accounts."""
145145
query = user_db.session.query(User) # pylint: disable=no-member
146146
user = query.filter_by(name=name).scalar()
147147
if user is None:
148148
raise ValueError(f"User {name} not found")
149+
150+
# Manually delete associated OIDC accounts first.
151+
# This is needed because SQLite does not enforce foreign key constraints by default.
152+
user_db.session.query(OIDCAccount).filter_by(
153+
user_id=user.id
154+
).delete() # pylint: disable=no-member
155+
149156
user_db.session.delete(user) # pylint: disable=no-member
150157
user_db.session.commit() # pylint: disable=no-member
151158

@@ -196,7 +203,9 @@ def get_pwhash(username: str) -> str:
196203
return user.pwhash
197204

198205

199-
def _get_user_detail(user, include_guid: bool = False, include_oidc_accounts: bool = False):
206+
def _get_user_detail(
207+
user, include_guid: bool = False, include_oidc_accounts: bool = False
208+
):
200209
details = {
201210
"name": user.name,
202211
"email": user.email,
@@ -228,7 +237,10 @@ def get_user_details(username: str) -> Optional[Dict[str, Any]]:
228237

229238

230239
def get_all_user_details(
231-
tree: Optional[str], include_treeless=False, include_guid: bool = False, include_oidc_accounts: bool = False
240+
tree: Optional[str],
241+
include_treeless=False,
242+
include_guid: bool = False,
243+
include_oidc_accounts: bool = False,
232244
) -> List[Dict[str, Any]]:
233245
"""Return details about all users.
234246
@@ -245,7 +257,12 @@ def get_all_user_details(
245257
else:
246258
query = query.filter(User.tree == tree)
247259
users = query.all()
248-
return [_get_user_detail(user, include_guid=include_guid, include_oidc_accounts=include_oidc_accounts) for user in users]
260+
return [
261+
_get_user_detail(
262+
user, include_guid=include_guid, include_oidc_accounts=include_oidc_accounts
263+
)
264+
for user in users
265+
]
249266

250267

251268
def get_permissions(username: str, tree: str) -> Set[str]:
@@ -435,10 +452,7 @@ def is_tree_disabled(tree: str) -> bool:
435452

436453

437454
def create_oidc_account(
438-
user_id: str,
439-
provider_id: str,
440-
subject_id: str,
441-
email: Optional[str] = None
455+
user_id: str, provider_id: str, subject_id: str, email: Optional[str] = None
442456
) -> None:
443457
"""Create a new OIDC account association."""
444458
oidc_account = OIDCAccount(
@@ -454,20 +468,25 @@ def create_oidc_account(
454468
def get_oidc_account(provider_id: str, subject_id: str) -> Optional[str]:
455469
"""Get user ID by OIDC provider_id and subject_id."""
456470
query = user_db.session.query(OIDCAccount.user_id) # pylint: disable=no-member
457-
oidc_account = query.filter_by(provider_id=provider_id, subject_id=subject_id).scalar()
471+
oidc_account = query.filter_by(
472+
provider_id=provider_id, subject_id=subject_id
473+
).scalar()
458474
return oidc_account
459475

460476

461477
def get_user_oidc_accounts(user_id: str) -> List[Dict[str, Any]]:
462478
"""Get all OIDC accounts associated with a user."""
463479
query = user_db.session.query(OIDCAccount) # pylint: disable=no-member
464480
oidc_accounts = query.filter_by(user_id=user_id).all()
465-
return [{
466-
"provider_id": account.provider_id,
467-
"subject_id": account.subject_id,
468-
"email": account.email,
469-
"created_at": account.created_at,
470-
} for account in oidc_accounts]
481+
return [
482+
{
483+
"provider_id": account.provider_id,
484+
"subject_id": account.subject_id,
485+
"email": account.email,
486+
"created_at": account.created_at,
487+
}
488+
for account in oidc_accounts
489+
]
471490

472491

473492
class User(user_db.Model): # type: ignore
@@ -528,14 +547,20 @@ class OIDCAccount(user_db.Model): # type: ignore
528547
__tablename__ = "oidc_accounts"
529548

530549
id = mapped_column(sa.Integer, primary_key=True, autoincrement=True)
531-
user_id = mapped_column(GUID, sa.ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True)
550+
user_id = mapped_column(
551+
GUID, sa.ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True
552+
)
532553
provider_id = mapped_column(sa.String(64), nullable=False)
533554
subject_id = mapped_column(sa.String(255), nullable=False)
534555
email = mapped_column(sa.String(255), nullable=True, index=True)
535-
created_at = mapped_column(sa.DateTime, nullable=False, server_default=sa.func.now())
556+
created_at = mapped_column(
557+
sa.DateTime, nullable=False, server_default=sa.func.now()
558+
)
536559

537560
__table_args__ = (
538-
sa.UniqueConstraint('provider_id', 'subject_id', name='uq_oidc_provider_subject'),
561+
sa.UniqueConstraint(
562+
"provider_id", "subject_id", name="uq_oidc_provider_subject"
563+
),
539564
)
540565

541566
def __repr__(self):

0 commit comments

Comments
 (0)