Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from fastapi import HTTPException, Request

from lcfs.db.models.compliance.ComplianceReportStatus import ComplianceReportStatusEnum
from lcfs.db.models.user.Role import RoleEnum
from lcfs.web.api.compliance_report.validation import ComplianceReportValidation
from lcfs.web.api.compliance_report.repo import ComplianceReportRepository

Expand Down Expand Up @@ -114,3 +115,76 @@ async def test_validate_editable_with_various_non_editable_statuses(

assert exc_info.value.status_code == 403
assert "cannot be edited" in exc_info.value.detail


def _make_validator_with_user(mock_repo, org_id=1, roles=None):
"""Build a ComplianceReportValidation instance for the given org/roles."""
if roles is None:
roles = []
user = MockUser(organization_id=org_id, roles=roles)
user.role_names = roles # user_has_roles reads role_names
request = MockRequest(user)
return ComplianceReportValidation(request=request, repo=mock_repo)


@pytest.mark.anyio
class TestValidateOrganizationAccessByGroupUuid:
"""Tests for validate_organization_access_by_group_uuid."""

async def test_no_reports_raises_404(self, mock_repo):
"""Returns 404 when the group UUID matches no reports."""
mock_repo.get_compliance_report_chain.return_value = []
validator = _make_validator_with_user(mock_repo)

with pytest.raises(HTTPException) as exc_info:
await validator.validate_organization_access_by_group_uuid("unknown-uuid")

assert exc_info.value.status_code == 404

async def test_supplier_same_org_passes(self, mock_repo):
"""Supplier belonging to the report's org passes without exception."""
report = MockComplianceReport(ComplianceReportStatusEnum.Draft, organization_id=1)
mock_repo.get_compliance_report_chain.return_value = [report]
validator = _make_validator_with_user(mock_repo, org_id=1, roles=[])

await validator.validate_organization_access_by_group_uuid("test-uuid")

async def test_supplier_different_org_raises_403(self, mock_repo):
"""Supplier from a different org is rejected with 403."""
report = MockComplianceReport(ComplianceReportStatusEnum.Draft, organization_id=1)
mock_repo.get_compliance_report_chain.return_value = [report]
validator = _make_validator_with_user(mock_repo, org_id=99, roles=[])

with pytest.raises(HTTPException) as exc_info:
await validator.validate_organization_access_by_group_uuid("test-uuid")

assert exc_info.value.status_code == 403
assert "does not have access" in exc_info.value.detail

async def test_government_user_any_org_passes(self, mock_repo):
"""Government user can access reports from any organization."""
report = MockComplianceReport(ComplianceReportStatusEnum.Draft, organization_id=42)
mock_repo.get_compliance_report_chain.return_value = [report]

user = MockUser(organization_id=1)
user.role_names = [RoleEnum.GOVERNMENT]
request = MockRequest(user)
validator = ComplianceReportValidation(request=request, repo=mock_repo)

await validator.validate_organization_access_by_group_uuid("test-uuid")

async def test_supplier_no_organization_raises_403(self, mock_repo):
"""Supplier with no organization is denied access."""
report = MockComplianceReport(ComplianceReportStatusEnum.Draft, organization_id=1)
mock_repo.get_compliance_report_chain.return_value = [report]

user = MockUser(organization_id=1)
user.role_names = []
user.organization = None
request = MockRequest(user)
validator = ComplianceReportValidation(request=request, repo=mock_repo)

with pytest.raises(HTTPException) as exc_info:
await validator.validate_organization_access_by_group_uuid("test-uuid")

assert exc_info.value.status_code == 403
80 changes: 80 additions & 0 deletions backend/lcfs/tests/organization/test_organization_views.py
Original file line number Diff line number Diff line change
Expand Up @@ -458,3 +458,83 @@ async def test_get_compliance_report_by_id_success(
mock_compliance_report_validation.validate_compliance_report_access.assert_awaited_once_with(
mock.ANY
)


_UPDATE_USER_PAYLOAD = {
"title": "Mr",
"keycloak_username": "testuser",
"keycloak_email": "testuser@example.com",
"first_name": "Test",
"last_name": "User",
"is_active": True,
"phone": None,
"mobile_phone": None,
"organization_id": 1,
"role_names": [],
}

_USER_RESPONSE = {
"user_profile_id": 42,
"keycloak_username": "testuser",
"keycloak_email": "testuser@example.com",
"is_active": True,
"first_name": "Test",
"last_name": "User",
"title": "Mr",
"phone": None,
"mobile_phone": None,
"organization_id": 1,
"roles": [],
}


@pytest.mark.anyio
async def test_update_user_supplier_own_org_succeeds(
client: AsyncClient,
fastapi_app: FastAPI,
set_mock_user,
mock_user_services,
):
"""A supplier with MANAGE_USERS can update a user in their own org."""
set_mock_user(fastapi_app, [RoleEnum.MANAGE_USERS], {"organization_id": 1})

mock_user_services.update_user.return_value = None
mock_user_services.get_user_by_id.return_value = _USER_RESPONSE
fastapi_app.dependency_overrides[UserServices] = lambda: mock_user_services

url = fastapi_app.url_path_for("update_user", organization_id=1, user_id=42)
response = await client.put(url, json=_UPDATE_USER_PAYLOAD)

assert response.status_code == 200


@pytest.mark.anyio
async def test_update_user_supplier_other_org_forbidden(
client: AsyncClient,
fastapi_app: FastAPI,
set_mock_user,
):
"""A supplier with MANAGE_USERS cannot update a user in a different org."""
set_mock_user(fastapi_app, [RoleEnum.MANAGE_USERS], {"organization_id": 99})

url = fastapi_app.url_path_for("update_user", organization_id=1, user_id=42)
response = await client.put(url, json=_UPDATE_USER_PAYLOAD)

assert response.status_code == 403


@pytest.mark.anyio
async def test_update_user_supplier_no_organization_forbidden(
client: AsyncClient,
fastapi_app: FastAPI,
set_mock_user,
):
"""A supplier with no organization object set cannot update users."""
# Use org_id=0 so that the mock auth creates an org with id=0,
# which doesn't match organization_id=1 in the URL.
set_mock_user(fastapi_app, [RoleEnum.MANAGE_USERS], {"organization_id": 0})

url = fastapi_app.url_path_for("update_user", organization_id=1, user_id=42)
response = await client.put(url, json=_UPDATE_USER_PAYLOAD)

assert response.status_code == 403
170 changes: 170 additions & 0 deletions backend/lcfs/tests/organizations/test_organizations_views.py
Original file line number Diff line number Diff line change
Expand Up @@ -621,3 +621,173 @@ async def test_regenerate_organization_link_key_user_context(
assert len(args) >= 3
assert args[0] == 1
assert args[1] == 1


def _mock_org_schema(org_id=1, name="Test Org", total_balance=500, reserved_balance=50):
"""Return an OrganizationResponseSchema-compatible SimpleNamespace."""
from types import SimpleNamespace

return SimpleNamespace(
organization_id=org_id,
name=name,
operating_name=name,
has_early_issuance=False,
total_balance=total_balance,
reserved_balance=reserved_balance,
email=None,
phone=None,
edrms_record=None,
credit_market_contact_name=None,
credit_market_contact_email=None,
credit_market_contact_phone=None,
credit_market_is_seller=False,
credit_market_is_buyer=False,
credits_to_sell=0,
display_in_credit_market=False,
company_details=None,
company_representation_agreements=None,
company_acting_as_aggregator=None,
company_additional_notes=None,
organization_type_id=None,
org_status=None,
org_type=None,
records_address=None,
org_address=None,
org_attorney_address=None,
)


class TestGetOrganizationBalanceStripping:
"""Tests that compliance unit balances are only exposed to authorized users."""

@pytest.mark.anyio
async def test_government_user_sees_balance(
self, fastapi_app: FastAPI, set_mock_user
):
"""Government users receive total_balance and reserved_balance."""
set_mock_user(fastapi_app, [RoleEnum.GOVERNMENT])

with patch(
"lcfs.web.api.organizations.views.OrganizationsService"
) as mock_service_cls:
mock_svc = AsyncMock()
mock_service_cls.return_value = mock_svc
mock_svc.get_organization.return_value = _mock_org_schema(org_id=1)
fastapi_app.dependency_overrides[ServiceDependency] = lambda: mock_svc

async with AsyncClient(app=fastapi_app, base_url="http://test") as client:
response = await client.get("/api/organizations/1")

assert response.status_code == status.HTTP_200_OK
data = response.json()
assert data["totalBalance"] == 500
assert data["reservedBalance"] == 50

@pytest.mark.anyio
async def test_supplier_own_org_sees_balance(
self, fastapi_app: FastAPI, set_mock_user
):
"""A supplier accessing their own organization sees balance fields."""
set_mock_user(fastapi_app, [RoleEnum.SUPPLIER], {"organization_id": 1})

with patch(
"lcfs.web.api.organizations.views.OrganizationsService"
) as mock_service_cls:
mock_svc = AsyncMock()
mock_service_cls.return_value = mock_svc
mock_svc.get_organization.return_value = _mock_org_schema(org_id=1)
fastapi_app.dependency_overrides[ServiceDependency] = lambda: mock_svc

async with AsyncClient(app=fastapi_app, base_url="http://test") as client:
response = await client.get("/api/organizations/1")

assert response.status_code == status.HTTP_200_OK
data = response.json()
assert data["totalBalance"] == 500
assert data["reservedBalance"] == 50

@pytest.mark.anyio
async def test_supplier_other_org_balance_stripped(
self, fastapi_app: FastAPI, set_mock_user
):
"""A supplier accessing another org's record gets null balance fields."""
set_mock_user(fastapi_app, [RoleEnum.SUPPLIER], {"organization_id": 99})

with patch(
"lcfs.web.api.organizations.views.OrganizationsService"
) as mock_service_cls:
mock_svc = AsyncMock()
mock_service_cls.return_value = mock_svc
mock_svc.get_organization.return_value = _mock_org_schema(
org_id=1, name="Other Org"
)
fastapi_app.dependency_overrides[ServiceDependency] = lambda: mock_svc

async with AsyncClient(app=fastapi_app, base_url="http://test") as client:
response = await client.get("/api/organizations/1")

assert response.status_code == status.HTTP_200_OK
data = response.json()
assert data["totalBalance"] is None
assert data["reservedBalance"] is None


class TestPenaltyEndpointsOrgCheck:
"""Tests that penalty analytics/logs enforce org ownership for suppliers."""

@pytest.mark.anyio
async def test_supplier_penalty_analytics_own_org_allowed(
self, fastapi_app: FastAPI, set_mock_user
):
"""Supplier can fetch penalty analytics for their own org."""
set_mock_user(fastapi_app, [RoleEnum.SUPPLIER], {"organization_id": 1})

with patch(
"lcfs.web.api.organizations.views.OrganizationsService"
) as mock_service_cls:
mock_svc = AsyncMock()
mock_service_cls.return_value = mock_svc
mock_svc.get_penalty_analytics.return_value = MagicMock(
yearly_summaries=[], automatic_total=0, discretionary_total=0
)
fastapi_app.dependency_overrides[ServiceDependency] = lambda: mock_svc

async with AsyncClient(app=fastapi_app, base_url="http://test") as client:
response = await client.get("/api/organizations/1/penalties/analytics")

assert response.status_code == status.HTTP_200_OK

@pytest.mark.anyio
async def test_supplier_penalty_analytics_other_org_forbidden(
self, fastapi_app: FastAPI, set_mock_user
):
"""Supplier cannot fetch penalty analytics for another org."""
set_mock_user(fastapi_app, [RoleEnum.SUPPLIER], {"organization_id": 99})

with patch("lcfs.web.api.organizations.views.OrganizationsService"):
async with AsyncClient(app=fastapi_app, base_url="http://test") as client:
response = await client.get("/api/organizations/1/penalties/analytics")

assert response.status_code == status.HTTP_403_FORBIDDEN

@pytest.mark.anyio
async def test_government_penalty_analytics_any_org_allowed(
self, fastapi_app: FastAPI, set_mock_user
):
"""Government users can fetch penalty analytics for any org."""
set_mock_user(fastapi_app, [RoleEnum.GOVERNMENT])

with patch(
"lcfs.web.api.organizations.views.OrganizationsService"
) as mock_service_cls:
mock_svc = AsyncMock()
mock_service_cls.return_value = mock_svc
mock_svc.get_penalty_analytics.return_value = MagicMock(
yearly_summaries=[], automatic_total=0, discretionary_total=0
)
fastapi_app.dependency_overrides[ServiceDependency] = lambda: mock_svc

async with AsyncClient(app=fastapi_app, base_url="http://test") as client:
response = await client.get("/api/organizations/5/penalties/analytics")

assert response.status_code == status.HTTP_200_OK
17 changes: 17 additions & 0 deletions backend/lcfs/tests/transfer/test_transfer_repo.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

@pytest.mark.anyio
async def test_get_all_transfers_success(transfer_repo, mock_db):
"""No org filter: returns all transfers."""
expected_data = []
mock_result = MagicMock()
mock_result.scalars.return_value.all.return_value = expected_data
Expand All @@ -25,6 +26,22 @@ async def test_get_all_transfers_success(transfer_repo, mock_db):
assert result == expected_data


@pytest.mark.anyio
async def test_get_all_transfers_with_org_filter(transfer_repo, mock_db):
"""With organization_id filter: query executes with WHERE clause."""
expected_data = []
mock_result = MagicMock()
mock_result.scalars.return_value.all.return_value = expected_data

mock_db.execute.return_value = mock_result

result = await transfer_repo.get_all_transfers(organization_id=5)

mock_db.execute.assert_called_once()
mock_result.scalars.return_value.all.assert_called_once()
assert result == expected_data


@pytest.mark.anyio
async def test_get_transfer_by_id_success(transfer_repo, mock_db):
transfer_id = 1
Expand Down
Loading
Loading