Skip to content

Commit 4359ba5

Browse files
Copilotblarghmateyjkachelpre-commit-ci[bot]
authored
Fix remove_user_contracts to maintain user filtering and only remove managed contracts (#3031)
Co-authored-by: copilot-swe-agent[bot] <[email protected]> Co-authored-by: blarghmatey <[email protected]> Co-authored-by: jkachel <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent f1e208e commit 4359ba5

File tree

2 files changed

+103
-3
lines changed

2 files changed

+103
-3
lines changed

b2b/models.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -162,7 +162,7 @@ def add_user_contracts(self, user):
162162

163163
return contracts_qs.count()
164164

165-
def remove_user_contracts(self, user): # noqa: ARG002
165+
def remove_user_contracts(self, user):
166166
"""
167167
Remove managed contracts from the given user.
168168
@@ -171,7 +171,13 @@ def remove_user_contracts(self, user): # noqa: ARG002
171171
Returns:
172172
- int: number of contracts removed
173173
"""
174-
return 0
174+
175+
return user.b2b_contracts.through.objects.filter(
176+
user_id=user.id,
177+
contractpage_id__in=self.contracts.filter(
178+
integration_type__in=CONTRACT_MEMBERSHIP_AUTOS
179+
).values_list("id", flat=True),
180+
).delete()
175181

176182
def __str__(self):
177183
"""Return a reasonable representation of the org as a string."""

b2b/models_test.py

Lines changed: 95 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,11 @@
44
import pytest
55

66
from b2b.factories import ContractPageFactory, OrganizationPageFactory
7-
from courses.factories import CourseRunFactory, ProgramFactory
7+
from courses.factories import (
8+
CourseRunFactory,
9+
ProgramFactory,
10+
)
11+
from users.factories import UserFactory
812

913
pytestmark = [pytest.mark.django_db]
1014
FAKE = faker.Faker()
@@ -94,3 +98,93 @@ def test_organization_page_slug_not_overwritten_if_set():
9498
# The slug should still be the custom one
9599
assert org.slug == "custom-slug"
96100
assert org.title == "Test Org Updated"
101+
102+
103+
def test_remove_user_contracts_only_affects_specified_user():
104+
"""Test that remove_user_contracts only removes contracts for the specified user."""
105+
106+
# Create an organization and contracts
107+
organization = OrganizationPageFactory.create()
108+
contract1 = ContractPageFactory.create(
109+
organization=organization,
110+
membership_type="auto",
111+
integration_type="auto",
112+
)
113+
contract2 = ContractPageFactory.create(
114+
organization=organization,
115+
membership_type="managed",
116+
integration_type="managed",
117+
)
118+
119+
# Create two users and add them both to the contracts
120+
user1 = UserFactory.create()
121+
user2 = UserFactory.create()
122+
123+
user1.b2b_contracts.add(contract1, contract2)
124+
user2.b2b_contracts.add(contract1, contract2)
125+
126+
# Verify both users have the contracts
127+
assert user1.b2b_contracts.count() == 2
128+
assert user2.b2b_contracts.count() == 2
129+
130+
# Remove contracts from user1
131+
organization.remove_user_contracts(user1)
132+
133+
# Verify user1's contracts are removed
134+
user1.refresh_from_db()
135+
assert user1.b2b_contracts.count() == 0
136+
137+
# Verify user2's contracts are NOT affected
138+
user2.refresh_from_db()
139+
assert user2.b2b_contracts.count() == 2
140+
assert user2.b2b_contracts.filter(id=contract1.id).exists()
141+
assert user2.b2b_contracts.filter(id=contract2.id).exists()
142+
143+
144+
def test_remove_user_contracts_only_removes_managed_contracts():
145+
"""Test that remove_user_contracts only removes automatically managed contracts."""
146+
147+
# Create an organization with both managed and non-managed contracts
148+
organization = OrganizationPageFactory.create()
149+
150+
# Automatically managed contracts (should be removed)
151+
auto_contract = ContractPageFactory.create(
152+
organization=organization,
153+
membership_type="auto",
154+
integration_type="auto",
155+
)
156+
managed_contract = ContractPageFactory.create(
157+
organization=organization,
158+
membership_type="managed",
159+
integration_type="managed",
160+
)
161+
sso_contract = ContractPageFactory.create(
162+
organization=organization,
163+
membership_type="sso",
164+
integration_type="sso",
165+
)
166+
167+
# Non-managed contract (should NOT be removed)
168+
code_contract = ContractPageFactory.create(
169+
organization=organization,
170+
membership_type="code",
171+
integration_type="code",
172+
)
173+
174+
# Create a user and add all contracts
175+
user = UserFactory.create()
176+
user.b2b_contracts.add(auto_contract, managed_contract, sso_contract, code_contract)
177+
178+
# Verify user has all 4 contracts
179+
assert user.b2b_contracts.count() == 4
180+
181+
# Remove managed contracts from user
182+
organization.remove_user_contracts(user)
183+
184+
# Verify only managed contracts are removed, code contract remains
185+
user.refresh_from_db()
186+
assert user.b2b_contracts.count() == 1
187+
assert not user.b2b_contracts.filter(id=auto_contract.id).exists()
188+
assert not user.b2b_contracts.filter(id=managed_contract.id).exists()
189+
assert not user.b2b_contracts.filter(id=sso_contract.id).exists()
190+
assert user.b2b_contracts.filter(id=code_contract.id).exists()

0 commit comments

Comments
 (0)