Skip to content

Commit 51bdadc

Browse files
authored
Add many-to-many to plaforms-sign_keys (AlmaLinux/build-system#370) (#1041)
* Add many-to-many to plaforms-sign_keys (AlmaLinux/build-system#370) - alembic migration - change in models - added a test for one platform with 2 keys * Changing POST sign_keys schema to accept list of platforms + bug fixes * Important bug fix for alembic downgrade: - Disable delete on cascade for downgrade to preserve referencing tables * Adding platform_ids to /sign-keys/ response
1 parent ccb49aa commit 51bdadc

File tree

7 files changed

+253
-38
lines changed

7 files changed

+253
-38
lines changed
Lines changed: 142 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,142 @@
1+
"""Add many to many relationship between platforms and sign_keys
2+
3+
Revision ID: 63bc4e0b699f
4+
Revises: 764a67b23038
5+
Create Date: 2024-11-01 12:26:26.326314
6+
7+
"""
8+
9+
import sqlalchemy as sa
10+
from alembic import op
11+
from sqlalchemy.engine import reflection
12+
13+
# revision identifiers, used by Alembic.
14+
revision = '63bc4e0b699f'
15+
down_revision = '9977cc722e24'
16+
branch_labels = None
17+
depends_on = None
18+
19+
20+
def disable_fk_checks():
21+
op.execute(sa.text("SET session_replication_role = 'replica'"))
22+
23+
def enable_fk_checks():
24+
op.execute(sa.text("SET session_replication_role = 'origin'"))
25+
26+
def get_columns(table_name):
27+
# Use SQLAlchemy's Inspector to retrieve the columns of the specified table
28+
conn = op.get_bind()
29+
inspector = reflection.Inspector.from_engine(conn)
30+
columns = [col["name"] for col in inspector.get_columns(table_name)]
31+
return columns
32+
33+
34+
def create_backup_table(table_name):
35+
backup_table = f'{table_name}_backup'
36+
op.execute(
37+
sa.text(f"CREATE TABLE {backup_table} AS TABLE {table_name} WITH DATA")
38+
)
39+
40+
41+
def drop_backup_table(backup_table_name):
42+
op.execute(sa.text(f"DROP TABLE IF EXISTS {backup_table_name}"))
43+
44+
45+
def restore_from_backup(table_name):
46+
backup_table = f'{table_name}_backup'
47+
columns = get_columns(table_name)
48+
column_list = ', '.join(columns)
49+
disable_fk_checks()
50+
op.execute(sa.text(f"DELETE FROM {table_name}"))
51+
op.execute(
52+
sa.text(
53+
f"INSERT INTO {table_name} ({column_list}) SELECT {column_list} FROM {backup_table}"
54+
)
55+
)
56+
enable_fk_checks()
57+
58+
59+
def create_association_table():
60+
op.create_table(
61+
'platforms_sign_keys',
62+
sa.Column(
63+
'platform_id',
64+
sa.Integer,
65+
sa.ForeignKey('platforms.id', ondelete="CASCADE"),
66+
primary_key=True,
67+
),
68+
sa.Column(
69+
'sign_key_id',
70+
sa.Integer,
71+
sa.ForeignKey('sign_keys.id', ondelete="CASCADE"),
72+
primary_key=True,
73+
),
74+
)
75+
76+
77+
def upgrade():
78+
create_backup_table("platforms")
79+
create_backup_table("sign_keys")
80+
op.drop_constraint(
81+
'sign_keys_platform_id_fkey', 'sign_keys', type_='foreignkey'
82+
)
83+
create_association_table()
84+
op.execute(
85+
sa.text(
86+
"""
87+
INSERT INTO platforms_sign_keys (platform_id, sign_key_id)
88+
SELECT platform_id, id FROM sign_keys
89+
WHERE platform_id IS NOT NULL
90+
"""
91+
)
92+
)
93+
op.drop_column("sign_keys", "platform_id")
94+
95+
96+
def downgrade():
97+
op.add_column(
98+
"sign_keys", sa.Column("platform_id", sa.Integer, nullable=True)
99+
)
100+
op.execute(
101+
sa.text(
102+
"""
103+
UPDATE sign_keys
104+
SET platform_id = (
105+
SELECT platform_id
106+
FROM platforms_sign_keys
107+
WHERE platforms_sign_keys.sign_key_id = sign_keys.id
108+
LIMIT 1
109+
)
110+
"""
111+
)
112+
)
113+
op.create_foreign_key(
114+
'sign_keys_platform_id_fkey',
115+
'sign_keys',
116+
'platforms',
117+
['platform_id'],
118+
['id'],
119+
)
120+
op.drop_table('platforms_sign_keys')
121+
122+
restore_from_backup("platforms")
123+
restore_from_backup("sign_keys")
124+
125+
# Neccessary to preserve id sequences
126+
op.execute(
127+
sa.text(
128+
"""
129+
SELECT setval('platforms_id_seq', MAX(id)) FROM platforms;
130+
"""
131+
)
132+
)
133+
op.execute(
134+
sa.text(
135+
"""
136+
SELECT setval('sign_keys_id_seq', MAX(id)) FROM sign_keys;
137+
"""
138+
)
139+
)
140+
141+
drop_backup_table("platforms_backup")
142+
drop_backup_table("sign_keys_backup")

alws/crud/sign_key.py

Lines changed: 41 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
1+
import datetime
2+
import logging
13
import typing
24

5+
from sqlalchemy import and_, update
36
from sqlalchemy.ext.asyncio import AsyncSession
47
from sqlalchemy.future import select
5-
from sqlalchemy import and_, update
68
from sqlalchemy.orm import selectinload
7-
import datetime
89

910
from alws import models
1011
from alws.crud.user import get_user
@@ -25,11 +26,14 @@ async def get_sign_keys(
2526
) -> typing.List[models.SignKey]:
2627
limited_user = await get_user(db, user.id)
2728
result = await db.execute(
28-
select(models.SignKey).where(models.SignKey.active).options(
29+
select(models.SignKey)
30+
.where(models.SignKey.active)
31+
.options(
2932
selectinload(models.SignKey.owner),
3033
selectinload(models.SignKey.roles).selectinload(
3134
models.UserRole.actions
3235
),
36+
selectinload(models.SignKey.platforms),
3337
)
3438
)
3539
suitable_keys = [
@@ -40,6 +44,17 @@ async def get_sign_keys(
4044
return suitable_keys
4145

4246

47+
async def get_sign_key(db: AsyncSession, key_id: int):
48+
sign_key = await db.execute(
49+
select(models.SignKey)
50+
.where(models.SignKey.keyid == key_id)
51+
.options(selectinload(models.SignKey.platforms))
52+
)
53+
if not sign_key:
54+
raise DataNotFoundError(f"Sign key with ID {key_id} does not exist")
55+
return sign_key.scalars().first()
56+
57+
4358
async def create_sign_key(
4459
db: AsyncSession, payload: sign_schema.SignKeyCreate
4560
) -> models.SignKey:
@@ -50,21 +65,34 @@ async def create_sign_key(
5065
raise SignKeyAlreadyExistsError(
5166
f"Key with keyid {payload.keyid} already exists"
5267
)
53-
if payload.platform_id:
54-
check_platform = await db.execute(
55-
select(models.Platform.id).where(
56-
models.Platform.id == payload.platform_id
57-
)
68+
model = payload.model_dump()
69+
platform_ids = model.pop('platform_ids', None)
70+
sign_key = models.SignKey(**model)
71+
72+
if platform_ids:
73+
check_platforms = await db.execute(
74+
select(models.Platform).where(models.Platform.id.in_(platform_ids))
5875
)
59-
if not check_platform.scalars().first():
76+
platform_instances = check_platforms.scalars().all()
77+
if not platform_instances:
6078
raise PlatformMissingError(
61-
f"No platform with id '{payload.platform_id}' "
62-
"exists in the system"
79+
f"No platforms with ids '{platform_ids}' exist in the system"
80+
)
81+
if len(platform_instances) < len(platform_ids):
82+
db_platform_ids = [pl.id for pl in platform_instances]
83+
missing_platform_ids = [
84+
platform_id
85+
for platform_id in platform_ids
86+
if platform_id not in db_platform_ids
87+
]
88+
logging.warning(
89+
f"Platforms with ids: '{missing_platform_ids}' "
90+
"are missing in the system. Did not add them to the sign key."
6391
)
64-
sign_key = models.SignKey(**payload.model_dump())
92+
sign_key.platforms = platform_instances
6593
db.add(sign_key)
6694
await db.flush()
67-
await db.refresh(sign_key)
95+
await db.refresh(sign_key, attribute_names=['platforms'])
6896
return sign_key
6997

7098

alws/models.py

Lines changed: 30 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -230,6 +230,23 @@ def finished_at(cls) -> Mapped[Optional[datetime.datetime]]:
230230
),
231231
)
232232

233+
platforms_sign_keys = sqlalchemy.Table(
234+
"platforms_sign_keys",
235+
Base.metadata,
236+
sqlalchemy.Column(
237+
"platform_id",
238+
sqlalchemy.Integer,
239+
sqlalchemy.ForeignKey("platforms.id", ondelete="CASCADE"),
240+
primary_key=True,
241+
),
242+
sqlalchemy.Column(
243+
"sign_key_id",
244+
sqlalchemy.Integer,
245+
sqlalchemy.ForeignKey("sign_keys.id", ondelete="CASCADE"),
246+
primary_key=True,
247+
),
248+
)
249+
233250

234251
class Platform(PermissionsMixin, Base):
235252
__tablename__ = "platforms"
@@ -244,7 +261,9 @@ class Platform(PermissionsMixin, Base):
244261
type: Mapped[str] = mapped_column(sqlalchemy.Text, nullable=False)
245262
distr_type: Mapped[str] = mapped_column(sqlalchemy.Text, nullable=False)
246263
distr_version: Mapped[str] = mapped_column(sqlalchemy.Text, nullable=False)
247-
pgp_key: Mapped[Optional[str]] = mapped_column(sqlalchemy.Text, nullable=True)
264+
pgp_key: Mapped[Optional[str]] = mapped_column(
265+
sqlalchemy.Text, nullable=True
266+
)
248267
module_build_index: Mapped[int] = mapped_column(
249268
sqlalchemy.Integer, default=1
250269
)
@@ -281,7 +300,7 @@ class Platform(PermissionsMixin, Base):
281300
"Repository", secondary=PlatformRepo
282301
)
283302
sign_keys: Mapped[List["SignKey"]] = relationship(
284-
"SignKey", back_populates="platform"
303+
"SignKey", secondary=platforms_sign_keys, back_populates="platforms"
285304
)
286305
roles: Mapped[List["UserRole"]] = relationship(
287306
"UserRole", secondary=PlatformRoleMapping
@@ -953,7 +972,7 @@ class Team(PermissionsMixin, Base):
953972
"Product", back_populates="team"
954973
)
955974
test_repositories: Mapped[List["TestRepository"]] = relationship(
956-
"TestRepository", back_populates="team"
975+
"TestRepository", back_populates="team"
957976
)
958977
roles: Mapped[List["UserRole"]] = relationship(
959978
"UserRole",
@@ -1199,7 +1218,9 @@ class TestRepository(PermissionsMixin, TeamMixin, Base):
11991218
back_populates="test_repository",
12001219
cascade="all, delete",
12011220
)
1202-
team: Mapped["Team"] = relationship("Team", back_populates="test_repositories")
1221+
team: Mapped["Team"] = relationship(
1222+
"Team", back_populates="test_repositories"
1223+
)
12031224
roles: Mapped[List["UserRole"]] = relationship(
12041225
"UserRole", secondary=TestRepositoryRoleMapping
12051226
)
@@ -1300,11 +1321,10 @@ class SignKey(PermissionsMixin, Base):
13001321
inserted: Mapped[datetime.datetime] = mapped_column(
13011322
sqlalchemy.DateTime, default=datetime.datetime.utcnow()
13021323
)
1303-
active: Mapped[bool] = mapped_column(
1304-
sqlalchemy.Boolean, default=True
1305-
)
1324+
active: Mapped[bool] = mapped_column(sqlalchemy.Boolean, default=True)
13061325
archived: Mapped[datetime.datetime] = mapped_column(
1307-
sqlalchemy.DateTime, nullable=True,
1326+
sqlalchemy.DateTime,
1327+
nullable=True,
13081328
)
13091329
product_id: Mapped[Optional[int]] = mapped_column(
13101330
sqlalchemy.Integer,
@@ -1317,16 +1337,8 @@ class SignKey(PermissionsMixin, Base):
13171337
product: Mapped["Product"] = relationship(
13181338
'Product', back_populates='sign_keys'
13191339
)
1320-
platform_id: Mapped[Optional[int]] = mapped_column(
1321-
sqlalchemy.Integer,
1322-
sqlalchemy.ForeignKey(
1323-
"platforms.id",
1324-
name="sign_keys_platform_id_fkey",
1325-
),
1326-
nullable=True,
1327-
)
1328-
platform: Mapped["Platform"] = relationship(
1329-
"Platform", back_populates="sign_keys"
1340+
platforms: Mapped[List["Platform"]] = relationship(
1341+
"Platform", secondary=platforms_sign_keys, back_populates="sign_keys"
13301342
)
13311343
build_task_artifacts: Mapped[List["BuildTaskArtifact"]] = relationship(
13321344
"BuildTaskArtifact",

alws/routers/sign_key.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,11 @@ async def get_sign_keys(
2727
db: AsyncSession = Depends(AsyncSessionDependency(key=get_async_db_key())),
2828
user=Depends(get_current_user),
2929
):
30-
return await sign_key.get_sign_keys(db, user)
30+
keys = await sign_key.get_sign_keys(db, user)
31+
for key in keys:
32+
if key.platforms:
33+
key.platform_ids = [pl.id for pl in key.platforms]
34+
return keys
3135

3236

3337
@router.post(
@@ -42,7 +46,10 @@ async def create_sign_key(
4246
):
4347
try:
4448
payload.owner_id = user.id
45-
return await sign_key.create_sign_key(db, payload)
49+
key = await sign_key.create_sign_key(db, payload)
50+
if payload.platform_ids:
51+
key.platform_ids = [pl.id for pl in key.platforms]
52+
return key
4653
except (PlatformMissingError, SignKeyAlreadyExistsError) as e:
4754
raise HTTPException(status.HTTP_400_BAD_REQUEST, str(e))
4855

alws/schemas/sign_schema.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ class SignKey(BaseModel):
1313
inserted: datetime
1414
active: bool = True
1515
archived: typing.Optional[datetime] = None
16-
platform_id: typing.Optional[int] = None
16+
platform_ids: typing.Optional[typing.List[int]] = None
1717
product_id: typing.Optional[int] = None
1818

1919
class Config:
@@ -26,7 +26,7 @@ class SignKeyCreate(BaseModel):
2626
keyid: str
2727
fingerprint: str
2828
public_url: str
29-
platform_id: typing.Optional[int] = None
29+
platform_ids: typing.Optional[typing.List[int]] = None
3030
owner_id: typing.Optional[int] = None
3131

3232

tests/fixtures/sign_keys.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,18 @@ def basic_sign_key_payload() -> dict:
2626
}
2727

2828

29+
@pytest.fixture
30+
def additional_sign_key_payload() -> dict:
31+
keyid = "01234567890ABCDE"
32+
return {
33+
"name": "Test key",
34+
"description": "Test sign key",
35+
"keyid": keyid,
36+
"fingerprint": "01234567890ABCDEF1234567890ABCDEF1234567",
37+
"public_url": "no_url",
38+
}
39+
40+
2941
async def __create_sign_key(
3042
async_session: AsyncSession, payload: dict
3143
) -> SignKey:

0 commit comments

Comments
 (0)