Skip to content

Commit 06d1c8e

Browse files
committed
fix(tests): use real SQLAlchemy models in RLS tests
1 parent 12b4549 commit 06d1c8e

File tree

1 file changed

+22
-31
lines changed

1 file changed

+22
-31
lines changed

tests/unit/test_rls.py

Lines changed: 22 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -267,36 +267,34 @@ class TestVerifyRlsAccess:
267267
@pytest.mark.anyio
268268
async def test_returns_true_when_record_found(self, mock_session: AsyncMock):
269269
"""Test that verify_rls_access returns True when record is accessible."""
270-
# Mock table class
271-
mock_table = MagicMock()
272-
mock_table.id = MagicMock()
273-
mock_table.__tablename__ = "test_table"
274-
275270
# Mock query result to return a record
276271
mock_result = MagicMock()
277272
mock_result.scalar_one_or_none.return_value = MagicMock()
278273
mock_session.execute.return_value = mock_result
279274

280-
result = await verify_rls_access(mock_session, mock_table, uuid.uuid4())
275+
# Use actual SQLAlchemy model for testing
276+
from tracecat.db.models import Workflow
277+
278+
result = await verify_rls_access(mock_session, Workflow, uuid.uuid4())
281279

282280
assert result is True
281+
mock_session.execute.assert_called_once()
283282

284283
@pytest.mark.anyio
285284
async def test_returns_false_when_record_not_found(self, mock_session: AsyncMock):
286285
"""Test that verify_rls_access returns False when RLS blocks access."""
287-
# Mock table class
288-
mock_table = MagicMock()
289-
mock_table.id = MagicMock()
290-
mock_table.__tablename__ = "test_table"
291-
292286
# Mock query result to return None (RLS blocked)
293287
mock_result = MagicMock()
294288
mock_result.scalar_one_or_none.return_value = None
295289
mock_session.execute.return_value = mock_result
296290

297-
result = await verify_rls_access(mock_session, mock_table, uuid.uuid4())
291+
# Use actual SQLAlchemy model for testing
292+
from tracecat.db.models import Workflow
293+
294+
result = await verify_rls_access(mock_session, Workflow, uuid.uuid4())
298295

299296
assert result is False
297+
mock_session.execute.assert_called_once()
300298

301299

302300
class TestRequireRlsAccess:
@@ -309,11 +307,10 @@ async def test_noop_when_rls_disabled(
309307
"""Test that require_rls_access does nothing when RLS is disabled."""
310308
monkeypatch.setattr("tracecat.db.rls.is_rls_enabled", lambda: False)
311309

312-
mock_table = MagicMock()
313-
mock_table.__tablename__ = "test_table"
310+
from tracecat.db.models import Workflow
314311

315312
# Should not raise even without setting up mock
316-
await require_rls_access(mock_session, mock_table, uuid.uuid4())
313+
await require_rls_access(mock_session, Workflow, uuid.uuid4())
317314

318315
@pytest.mark.anyio
319316
async def test_passes_when_access_allowed(
@@ -325,17 +322,15 @@ async def test_passes_when_access_allowed(
325322
{FeatureFlag.RLS_ENABLED},
326323
)
327324

328-
mock_table = MagicMock()
329-
mock_table.id = MagicMock()
330-
mock_table.__tablename__ = "test_table"
325+
from tracecat.db.models import Workflow
331326

332327
# Mock query result to return a record
333328
mock_result = MagicMock()
334329
mock_result.scalar_one_or_none.return_value = MagicMock()
335330
mock_session.execute.return_value = mock_result
336331

337332
# Should not raise
338-
await require_rls_access(mock_session, mock_table, uuid.uuid4())
333+
await require_rls_access(mock_session, Workflow, uuid.uuid4())
339334

340335
@pytest.mark.anyio
341336
async def test_raises_when_access_denied(
@@ -347,13 +342,11 @@ async def test_raises_when_access_denied(
347342
{FeatureFlag.RLS_ENABLED},
348343
)
349344

345+
from tracecat.db.models import Workflow
346+
350347
# Set role context
351348
ctx_role.set(test_role)
352349

353-
mock_table = MagicMock()
354-
mock_table.id = MagicMock()
355-
mock_table.__tablename__ = "test_table"
356-
357350
# Mock query result to return None (RLS blocked)
358351
mock_result = MagicMock()
359352
mock_result.scalar_one_or_none.return_value = None
@@ -362,10 +355,10 @@ async def test_raises_when_access_denied(
362355
try:
363356
with pytest.raises(TracecatRLSViolationError) as exc_info:
364357
await require_rls_access(
365-
mock_session, mock_table, uuid.uuid4(), operation="delete"
358+
mock_session, Workflow, uuid.uuid4(), operation="delete"
366359
)
367360

368-
assert exc_info.value.table == "test_table"
361+
assert exc_info.value.table == "workflow"
369362
assert exc_info.value.operation == "delete"
370363
assert exc_info.value.org_id == str(test_role.organization_id)
371364
assert exc_info.value.workspace_id == str(test_role.workspace_id)
@@ -382,11 +375,9 @@ async def test_logs_violation_on_access_denied(
382375
{FeatureFlag.RLS_ENABLED},
383376
)
384377

385-
ctx_role.set(test_role)
378+
from tracecat.db.models import Workflow
386379

387-
mock_table = MagicMock()
388-
mock_table.id = MagicMock()
389-
mock_table.__tablename__ = "test_table"
380+
ctx_role.set(test_role)
390381

391382
mock_result = MagicMock()
392383
mock_result.scalar_one_or_none.return_value = None
@@ -396,12 +387,12 @@ async def test_logs_violation_on_access_denied(
396387
try:
397388
with pytest.raises(TracecatRLSViolationError):
398389
await require_rls_access(
399-
mock_session, mock_table, uuid.uuid4(), operation="update"
390+
mock_session, Workflow, uuid.uuid4(), operation="update"
400391
)
401392

402393
mock_audit.assert_called_once()
403394
call_kwargs = mock_audit.call_args[1]
404-
assert call_kwargs["table"] == "test_table"
395+
assert call_kwargs["table"] == "workflow"
405396
assert call_kwargs["operation"] == "update"
406397
finally:
407398
ctx_role.set(None)

0 commit comments

Comments
 (0)