Skip to content

Commit 7362e63

Browse files
committed
fix: update KeyPairData dotfiles type to bytes and adjust related logic in ArtifactDBSource and tests
1 parent c52fb84 commit 7362e63

File tree

7 files changed

+64
-81
lines changed

7 files changed

+64
-81
lines changed

src/ai/backend/manager/data/keypair/types.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,5 +37,5 @@ class KeyPairData:
3737
rate_limit: int
3838
ssh_public_key: Optional[str]
3939
ssh_private_key: Optional[str]
40-
dotfiles: str
40+
dotfiles: bytes
4141
bootstrap_script: str

src/ai/backend/manager/models/keypair/row.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -175,7 +175,7 @@ def to_data(self) -> KeyPairData:
175175
rate_limit=self.rate_limit if self.rate_limit is not None else 0,
176176
ssh_public_key=self.ssh_public_key,
177177
ssh_private_key=self.ssh_private_key,
178-
dotfiles=self.dotfiles.decode("utf-8") if self.dotfiles else "",
178+
dotfiles=self.dotfiles if self.dotfiles else b"\x90",
179179
bootstrap_script=self.bootstrap_script,
180180
)
181181

src/ai/backend/manager/repositories/artifact/db_source/db_source.py

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -647,11 +647,7 @@ async def approve_artifact(self, revision_id: uuid.UUID) -> ArtifactRevisionData
647647
)
648648

649649
result = await db_sess.execute(update_stmt)
650-
updated_id = result.scalar_one_or_none()
651-
if updated_id is None:
652-
raise ArtifactUpdateError()
653-
654-
updated_row = await db_sess.get(ArtifactRevisionRow, updated_id)
650+
updated_row = result.scalars().one_or_none()
655651
if updated_row is None:
656652
raise ArtifactUpdateError()
657653

@@ -674,11 +670,7 @@ async def reject_artifact(self, revision_id: uuid.UUID) -> ArtifactRevisionData:
674670
)
675671

676672
result = await db_sess.execute(update_stmt)
677-
updated_id = result.scalar_one_or_none()
678-
if updated_id is None:
679-
raise ArtifactUpdateError()
680-
681-
updated_row = await db_sess.get(ArtifactRevisionRow, updated_id)
673+
updated_row = result.scalars().one_or_none()
682674
if updated_row is None:
683675
raise ArtifactUpdateError()
684676

tests/unit/manager/services/auth/test_authorize.py

Lines changed: 33 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -57,13 +57,14 @@ def setup_successful_auth(mock_auth_repository, mock_hook_plugin_ctx):
5757
mock_hook_plugin_ctx.dispatch.return_value = HookResult(
5858
status=HookResults.PASSED, result=None, reason=None
5959
)
60-
mock_auth_repository.check_credential_with_migration.return_value = {
61-
"uuid": UUID("12345678-1234-5678-1234-567812345678"),
62-
"email": "test@example.com",
63-
"role": UserRole.USER,
64-
"status": UserStatus.ACTIVE,
65-
"password_changed_at": None,
66-
}
60+
mock_user = MagicMock()
61+
mock_user.uuid = UUID("12345678-1234-5678-1234-567812345678")
62+
mock_user.email = "test@example.com"
63+
mock_user.role = UserRole.USER
64+
mock_user.status = UserStatus.ACTIVE
65+
mock_user.password_changed_at = None
66+
mock_user.__getitem__ = lambda self, key: getattr(self, key)
67+
mock_auth_repository.check_credential_with_migration.return_value = mock_user
6768
mock_user_row = MagicMock()
6869
mock_user_row.get_main_keypair_row.return_value = MagicMock(
6970
access_key="test_access_key",
@@ -149,14 +150,14 @@ async def test_authorize_with_hook_authorization(
149150
stoken=None,
150151
)
151152

152-
# Hook returns user data
153-
hook_user = {
154-
"uuid": UUID("87654321-4321-8765-4321-876543218765"),
155-
"email": "hook@example.com",
156-
"role": UserRole.ADMIN,
157-
"status": UserStatus.ACTIVE,
158-
"password_changed_at": None,
159-
}
153+
# Hook returns user data as MagicMock for attribute access
154+
hook_user = MagicMock()
155+
hook_user.uuid = UUID("87654321-4321-8765-4321-876543218765")
156+
hook_user.email = "hook@example.com"
157+
hook_user.role = UserRole.ADMIN
158+
hook_user.status = UserStatus.ACTIVE
159+
hook_user.password_changed_at = None
160+
hook_user.__getitem__ = lambda self, key: getattr(self, key)
160161
mock_hook_plugin_ctx.dispatch.return_value = HookResult(
161162
status=HookResults.PASSED,
162163
result=hook_user,
@@ -177,7 +178,7 @@ async def test_authorize_with_hook_authorization(
177178
assert result.authorization_result is not None
178179
assert result.authorization_result.access_key == "hook_access_key"
179180
assert result.authorization_result.secret_key == "hook_secret_key"
180-
assert result.authorization_result.user_id == hook_user["uuid"]
181+
assert result.authorization_result.user_id == hook_user.uuid
181182
assert result.authorization_result.role == UserRole.ADMIN
182183

183184

@@ -202,13 +203,14 @@ async def test_authorize_with_password_expiry(
202203

203204
# Setup expired password
204205
password_changed_at = datetime.now(tz=UTC) - timedelta(days=100)
205-
mock_auth_repository.check_credential_with_migration.return_value = {
206-
"uuid": UUID("12345678-1234-5678-1234-567812345678"),
207-
"email": "expired@example.com",
208-
"role": UserRole.USER,
209-
"status": UserStatus.ACTIVE,
210-
"password_changed_at": password_changed_at,
211-
}
206+
mock_user = MagicMock()
207+
mock_user.uuid = UUID("12345678-1234-5678-1234-567812345678")
208+
mock_user.email = "expired@example.com"
209+
mock_user.role = UserRole.USER
210+
mock_user.status = UserStatus.ACTIVE
211+
mock_user.password_changed_at = password_changed_at
212+
mock_user.__getitem__ = lambda self, key: getattr(self, key)
213+
mock_auth_repository.check_credential_with_migration.return_value = mock_user
212214
mock_auth_repository.get_current_time.return_value = datetime.now(tz=UTC)
213215

214216
mock_hook_plugin_ctx.dispatch.return_value = HookResult(
@@ -238,13 +240,14 @@ async def test_authorize_with_post_hook_response(
238240
)
239241

240242
# Setup successful credential check
241-
mock_auth_repository.check_credential_with_migration.return_value = {
242-
"uuid": UUID("12345678-1234-5678-1234-567812345678"),
243-
"email": "test@example.com",
244-
"role": UserRole.USER,
245-
"status": UserStatus.ACTIVE,
246-
"password_changed_at": None,
247-
}
243+
mock_user = MagicMock()
244+
mock_user.uuid = UUID("12345678-1234-5678-1234-567812345678")
245+
mock_user.email = "test@example.com"
246+
mock_user.role = UserRole.USER
247+
mock_user.status = UserStatus.ACTIVE
248+
mock_user.password_changed_at = None
249+
mock_user.__getitem__ = lambda self, key: getattr(self, key)
250+
mock_auth_repository.check_credential_with_migration.return_value = mock_user
248251

249252
# Mock user row
250253
mock_user_row = MagicMock()

tests/unit/manager/services/scaling_group/test_scaling_group_service.py

Lines changed: 10 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -434,12 +434,11 @@ async def test_check_scaling_group_raises_session_type_not_allowed(
434434
) -> None:
435435
"""Test that check_scaling_group raises ScalingGroupSessionTypeNotAllowed (422)
436436
when requesting BATCH session on INTERACTIVE-only scaling group"""
437-
mock_sgroup = {
438-
"name": "test-sgroup",
439-
"scheduler_opts": ScalingGroupOpts(
440-
allowed_session_types=[SessionTypes.INTERACTIVE],
441-
),
442-
}
437+
mock_sgroup = MagicMock()
438+
mock_sgroup.name = "test-sgroup"
439+
mock_sgroup.scheduler_opts = ScalingGroupOpts(
440+
allowed_session_types=[SessionTypes.INTERACTIVE],
441+
)
443442

444443
with patch(
445444
"ai.backend.manager.registry.query_allowed_sgroups",
@@ -462,12 +461,11 @@ async def test_check_scaling_group_succeeds_with_allowed_session_type(
462461
mock_conn: MagicMock,
463462
) -> None:
464463
"""Test that check_scaling_group succeeds when session type is allowed"""
465-
mock_sgroup = {
466-
"name": "test-sgroup",
467-
"scheduler_opts": ScalingGroupOpts(
468-
allowed_session_types=[SessionTypes.INTERACTIVE],
469-
),
470-
}
464+
mock_sgroup = MagicMock()
465+
mock_sgroup.name = "test-sgroup"
466+
mock_sgroup.scheduler_opts = ScalingGroupOpts(
467+
allowed_session_types=[SessionTypes.INTERACTIVE],
468+
)
471469

472470
with patch(
473471
"ai.backend.manager.registry.query_allowed_sgroups",

tests/unit/manager/services/session/test_session_service.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -389,7 +389,7 @@ async def test_empty_status_history(
389389
sample_session_id: SessionId,
390390
sample_access_key: AccessKey,
391391
) -> None:
392-
"""Test getting empty status history"""
392+
"""Test getting empty status history returns empty dict when None"""
393393
mock_session = MagicMock()
394394
mock_session.id = sample_session_id
395395
mock_session.status_history = None
@@ -402,7 +402,7 @@ async def test_empty_status_history(
402402
result = await session_service.get_status_history(action)
403403

404404
assert result.session_id == sample_session_id
405-
assert result.status_history is None
405+
assert result.status_history == {}
406406

407407

408408
# ==================== DestroySession Tests ====================
@@ -624,7 +624,7 @@ async def test_success(
624624
mock_kernel.image = "cr.backend.ai/stable/python:latest"
625625
mock_kernel.architecture = "x86_64"
626626
mock_kernel.registry = "cr.backend.ai"
627-
mock_kernel.container_id = uuid4()
627+
mock_kernel.container_id = str(uuid4())
628628
mock_kernel.occupied_slots = ResourceSlot({"cpu": 1, "mem": 1024})
629629
mock_kernel.occupied_shares = {}
630630

tests/unit/manager/test_scaling_group_for_enqueue.py

Lines changed: 14 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -14,28 +14,23 @@
1414
from ai.backend.manager.registry import check_scaling_group
1515

1616

17+
def _create_mock_sgroup(name: str, allowed_session_types: list[str]) -> MagicMock:
18+
"""Create a mock scaling group with proper attribute access."""
19+
mock = MagicMock()
20+
mock.name = name
21+
mock.scheduler_opts = ScalingGroupOpts.from_json({
22+
"allowed_session_types": allowed_session_types,
23+
})
24+
return mock
25+
26+
1727
@pytest.mark.asyncio
1828
@mock.patch("ai.backend.manager.registry.query_allowed_sgroups")
1929
async def test_allowed_session_types_check(mock_query):
2030
mock_query.return_value = [
21-
{
22-
"name": "a",
23-
"scheduler_opts": ScalingGroupOpts().from_json({
24-
"allowed_session_types": ["batch"],
25-
}),
26-
},
27-
{
28-
"name": "b",
29-
"scheduler_opts": ScalingGroupOpts().from_json({
30-
"allowed_session_types": ["interactive"],
31-
}),
32-
},
33-
{
34-
"name": "c",
35-
"scheduler_opts": ScalingGroupOpts().from_json({
36-
"allowed_session_types": ["batch", "interactive"],
37-
}),
38-
},
31+
_create_mock_sgroup("a", ["batch"]),
32+
_create_mock_sgroup("b", ["interactive"]),
33+
_create_mock_sgroup("c", ["batch", "interactive"]),
3934
]
4035
mock_conn = MagicMock()
4136
mock_sess_ctx = MagicMock()
@@ -126,12 +121,7 @@ async def test_allowed_session_types_check(mock_query):
126121
# No preferred scaling group with a non-empty list of allowed sgroups
127122

128123
mock_query.return_value = [
129-
{
130-
"name": "a",
131-
"scheduler_opts": ScalingGroupOpts.from_json({
132-
"allowed_session_types": ["batch"],
133-
}),
134-
},
124+
_create_mock_sgroup("a", ["batch"]),
135125
]
136126

137127
session_type = SessionTypes.BATCH

0 commit comments

Comments
 (0)