@@ -267,36 +267,34 @@ class TestVerifyRlsAccess:
267267 @pytest .mark .anyio
268268 async def test_returns_true_when_record_found (self , mock_session : AsyncMock ):
269269 """Test that verify_rls_access returns True when record is accessible."""
270- # Mock table class
271- mock_table = MagicMock ()
272- mock_table .id = MagicMock ()
273- mock_table .__tablename__ = "test_table"
274-
275270 # Mock query result to return a record
276271 mock_result = MagicMock ()
277272 mock_result .scalar_one_or_none .return_value = MagicMock ()
278273 mock_session .execute .return_value = mock_result
279274
280- result = await verify_rls_access (mock_session , mock_table , uuid .uuid4 ())
275+ # Use actual SQLAlchemy model for testing
276+ from tracecat .db .models import Workflow
277+
278+ result = await verify_rls_access (mock_session , Workflow , uuid .uuid4 ())
281279
282280 assert result is True
281+ mock_session .execute .assert_called_once ()
283282
284283 @pytest .mark .anyio
285284 async def test_returns_false_when_record_not_found (self , mock_session : AsyncMock ):
286285 """Test that verify_rls_access returns False when RLS blocks access."""
287- # Mock table class
288- mock_table = MagicMock ()
289- mock_table .id = MagicMock ()
290- mock_table .__tablename__ = "test_table"
291-
292286 # Mock query result to return None (RLS blocked)
293287 mock_result = MagicMock ()
294288 mock_result .scalar_one_or_none .return_value = None
295289 mock_session .execute .return_value = mock_result
296290
297- result = await verify_rls_access (mock_session , mock_table , uuid .uuid4 ())
291+ # Use actual SQLAlchemy model for testing
292+ from tracecat .db .models import Workflow
293+
294+ result = await verify_rls_access (mock_session , Workflow , uuid .uuid4 ())
298295
299296 assert result is False
297+ mock_session .execute .assert_called_once ()
300298
301299
302300class TestRequireRlsAccess :
@@ -309,11 +307,10 @@ async def test_noop_when_rls_disabled(
309307 """Test that require_rls_access does nothing when RLS is disabled."""
310308 monkeypatch .setattr ("tracecat.db.rls.is_rls_enabled" , lambda : False )
311309
312- mock_table = MagicMock ()
313- mock_table .__tablename__ = "test_table"
310+ from tracecat .db .models import Workflow
314311
315312 # Should not raise even without setting up mock
316- await require_rls_access (mock_session , mock_table , uuid .uuid4 ())
313+ await require_rls_access (mock_session , Workflow , uuid .uuid4 ())
317314
318315 @pytest .mark .anyio
319316 async def test_passes_when_access_allowed (
@@ -325,17 +322,15 @@ async def test_passes_when_access_allowed(
325322 {FeatureFlag .RLS_ENABLED },
326323 )
327324
328- mock_table = MagicMock ()
329- mock_table .id = MagicMock ()
330- mock_table .__tablename__ = "test_table"
325+ from tracecat .db .models import Workflow
331326
332327 # Mock query result to return a record
333328 mock_result = MagicMock ()
334329 mock_result .scalar_one_or_none .return_value = MagicMock ()
335330 mock_session .execute .return_value = mock_result
336331
337332 # Should not raise
338- await require_rls_access (mock_session , mock_table , uuid .uuid4 ())
333+ await require_rls_access (mock_session , Workflow , uuid .uuid4 ())
339334
340335 @pytest .mark .anyio
341336 async def test_raises_when_access_denied (
@@ -347,13 +342,11 @@ async def test_raises_when_access_denied(
347342 {FeatureFlag .RLS_ENABLED },
348343 )
349344
345+ from tracecat .db .models import Workflow
346+
350347 # Set role context
351348 ctx_role .set (test_role )
352349
353- mock_table = MagicMock ()
354- mock_table .id = MagicMock ()
355- mock_table .__tablename__ = "test_table"
356-
357350 # Mock query result to return None (RLS blocked)
358351 mock_result = MagicMock ()
359352 mock_result .scalar_one_or_none .return_value = None
@@ -362,10 +355,10 @@ async def test_raises_when_access_denied(
362355 try :
363356 with pytest .raises (TracecatRLSViolationError ) as exc_info :
364357 await require_rls_access (
365- mock_session , mock_table , uuid .uuid4 (), operation = "delete"
358+ mock_session , Workflow , uuid .uuid4 (), operation = "delete"
366359 )
367360
368- assert exc_info .value .table == "test_table "
361+ assert exc_info .value .table == "workflow "
369362 assert exc_info .value .operation == "delete"
370363 assert exc_info .value .org_id == str (test_role .organization_id )
371364 assert exc_info .value .workspace_id == str (test_role .workspace_id )
@@ -382,11 +375,9 @@ async def test_logs_violation_on_access_denied(
382375 {FeatureFlag .RLS_ENABLED },
383376 )
384377
385- ctx_role . set ( test_role )
378+ from tracecat . db . models import Workflow
386379
387- mock_table = MagicMock ()
388- mock_table .id = MagicMock ()
389- mock_table .__tablename__ = "test_table"
380+ ctx_role .set (test_role )
390381
391382 mock_result = MagicMock ()
392383 mock_result .scalar_one_or_none .return_value = None
@@ -396,12 +387,12 @@ async def test_logs_violation_on_access_denied(
396387 try :
397388 with pytest .raises (TracecatRLSViolationError ):
398389 await require_rls_access (
399- mock_session , mock_table , uuid .uuid4 (), operation = "update"
390+ mock_session , Workflow , uuid .uuid4 (), operation = "update"
400391 )
401392
402393 mock_audit .assert_called_once ()
403394 call_kwargs = mock_audit .call_args [1 ]
404- assert call_kwargs ["table" ] == "test_table "
395+ assert call_kwargs ["table" ] == "workflow "
405396 assert call_kwargs ["operation" ] == "update"
406397 finally :
407398 ctx_role .set (None )
0 commit comments