Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 37 additions & 11 deletions todo/repositories/task_repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,31 +9,43 @@
from todo.repositories.common.mongo_repository import MongoRepository
from todo.repositories.task_assignment_repository import TaskAssignmentRepository
from todo.constants.messages import ApiErrors, RepositoryErrors
from todo.constants.task import SORT_FIELD_PRIORITY, SORT_FIELD_ASSIGNEE, SORT_ORDER_DESC
from todo.constants.task import SORT_FIELD_PRIORITY, SORT_FIELD_ASSIGNEE, SORT_ORDER_DESC, TaskStatus


class TaskRepository(MongoRepository):
collection_name = TaskModel.collection_name

@classmethod
def list(
cls, page: int, limit: int, sort_by: str, order: str, user_id: str = None, team_id: str = None
cls,
page: int,
limit: int,
sort_by: str,
order: str,
user_id: str = None,
team_id: str = None,
status_filter: str = None,
) -> List[TaskModel]:
tasks_collection = cls.get_collection()
logger = logging.getLogger(__name__)

if status_filter:
base_filter = {"status": status_filter}
else:
base_filter = {"status": {"$ne": TaskStatus.DONE.value}}

if team_id:
logger.debug(f"TaskRepository.list: team_id={team_id}")
team_assignments = TaskAssignmentRepository.get_by_assignee_id(team_id, "team")
team_task_ids = [assignment.task_id for assignment in team_assignments]
logger.debug(f"TaskRepository.list: team_task_ids={team_task_ids}")
query_filter = {"_id": {"$in": team_task_ids}}
query_filter = {"$and": [base_filter, {"_id": {"$in": team_task_ids}}]}
logger.debug(f"TaskRepository.list: query_filter={query_filter}")
elif user_id:
assigned_task_ids = cls._get_assigned_task_ids_for_user(user_id)
query_filter = {"_id": {"$in": assigned_task_ids}}
query_filter = {"$and": [base_filter, {"_id": {"$in": assigned_task_ids}}]}
else:
query_filter = {}
query_filter = base_filter

if sort_by == SORT_FIELD_PRIORITY:
sort_direction = 1 if order == SORT_ORDER_DESC else -1
Expand Down Expand Up @@ -70,17 +82,25 @@ def _get_assigned_task_ids_for_user(cls, user_id: str) -> List[ObjectId]:
return direct_task_ids + team_task_ids

@classmethod
def count(cls, user_id: str = None, team_id: str = None) -> int:
def count(cls, user_id: str = None, team_id: str = None, status_filter: str = None) -> int:
tasks_collection = cls.get_collection()

if status_filter:
base_filter = {"status": status_filter}
else:
base_filter = {"status": {"$ne": TaskStatus.DONE.value}}

if team_id:
team_assignments = TaskAssignmentRepository.get_by_assignee_id(team_id, "team")
team_task_ids = [assignment.task_id for assignment in team_assignments]
query_filter = {"_id": {"$in": team_task_ids}}
query_filter = {"$and": [base_filter, {"_id": {"$in": team_task_ids}}]}
elif user_id:
assigned_task_ids = cls._get_assigned_task_ids_for_user(user_id)
query_filter = {"$or": [{"createdBy": user_id}, {"_id": {"$in": assigned_task_ids}}]}
query_filter = {
"$and": [base_filter, {"$or": [{"createdBy": user_id}, {"_id": {"$in": assigned_task_ids}}]}]
}
else:
query_filter = {}
query_filter = base_filter
return tasks_collection.count_documents(query_filter)

@classmethod
Expand Down Expand Up @@ -208,10 +228,16 @@ def update(cls, task_id: str, update_data: dict) -> TaskModel | None:
return None

@classmethod
def get_tasks_for_user(cls, user_id: str, page: int, limit: int) -> List[TaskModel]:
def get_tasks_for_user(cls, user_id: str, page: int, limit: int, status_filter: str = None) -> List[TaskModel]:
tasks_collection = cls.get_collection()
assigned_task_ids = cls._get_assigned_task_ids_for_user(user_id)
query = {"_id": {"$in": assigned_task_ids}}

if status_filter:
base_filter = {"status": status_filter}
else:
base_filter = {"status": {"$ne": TaskStatus.DONE.value}}

query = {"$and": [base_filter, {"_id": {"$in": assigned_task_ids}}]}
tasks_cursor = tasks_collection.find(query).skip((page - 1) * limit).limit(limit)
return [TaskModel(**task) for task in tasks_cursor]

Expand Down
2 changes: 2 additions & 0 deletions todo/serializers/get_tasks_serializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ class GetTaskQueryParamsSerializer(serializers.Serializer):

teamId = serializers.CharField(required=False, allow_blank=False, allow_null=True)

status = serializers.CharField(required=False, allow_blank=False, allow_null=True)

def validate(self, attrs):
validated_data = super().validate(attrs)

Expand Down
10 changes: 7 additions & 3 deletions todo/services/task_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ def get_tasks(
order: str,
user_id: str,
team_id: str = None,
status_filter: str = None,
) -> GetTasksResponse:
try:
cls._validate_pagination_params(page, limit)
Expand All @@ -89,8 +90,10 @@ def get_tasks(
},
)

tasks = TaskRepository.list(page, limit, sort_by, order, user_id, team_id=team_id)
total_count = TaskRepository.count(user_id, team_id=team_id)
tasks = TaskRepository.list(
page, limit, sort_by, order, user_id, team_id=team_id, status_filter=status_filter
)
total_count = TaskRepository.count(user_id, team_id=team_id, status_filter=status_filter)

if not tasks:
return GetTasksResponse(tasks=[], links=None)
Expand Down Expand Up @@ -665,9 +668,10 @@ def get_tasks_for_user(
user_id: str,
page: int = PaginationConfig.DEFAULT_PAGE,
limit: int = PaginationConfig.DEFAULT_LIMIT,
status_filter: str = None,
) -> GetTasksResponse:
cls._validate_pagination_params(page, limit)
tasks = TaskRepository.get_tasks_for_user(user_id, page, limit)
tasks = TaskRepository.get_tasks_for_user(user_id, page, limit, status_filter=status_filter)
if not tasks:
return GetTasksResponse(tasks=[], links=None)

Expand Down
24 changes: 18 additions & 6 deletions todo/tests/integration/test_task_sorting_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,9 @@ def test_priority_sorting_integration(self, mock_list, mock_count):
response = self.client.get("/v1/tasks", {"sort_by": "priority", "order": "desc"})

self.assertEqual(response.status_code, status.HTTP_200_OK)
mock_list.assert_called_with(1, 20, SORT_FIELD_PRIORITY, SORT_ORDER_DESC, str(self.user_id), team_id=None)
mock_list.assert_called_with(
1, 20, SORT_FIELD_PRIORITY, SORT_ORDER_DESC, str(self.user_id), team_id=None, status_filter=None
)

@patch("todo.repositories.task_repository.TaskRepository.count")
@patch("todo.repositories.task_repository.TaskRepository.list")
Expand All @@ -36,7 +38,9 @@ def test_due_at_default_order_integration(self, mock_list, mock_count):

self.assertEqual(response.status_code, status.HTTP_200_OK)

mock_list.assert_called_with(1, 20, SORT_FIELD_DUE_AT, SORT_ORDER_ASC, str(self.user_id), team_id=None)
mock_list.assert_called_with(
1, 20, SORT_FIELD_DUE_AT, SORT_ORDER_ASC, str(self.user_id), team_id=None, status_filter=None
)

@patch("todo.repositories.task_repository.TaskRepository.count")
@patch("todo.repositories.task_repository.TaskRepository.list")
Expand All @@ -49,7 +53,9 @@ def test_assignee_sorting_uses_aggregation(self, mock_list, mock_count):
self.assertEqual(response.status_code, status.HTTP_200_OK)

# Assignee sorting now falls back to createdAt sorting
mock_list.assert_called_once_with(1, 20, SORT_FIELD_ASSIGNEE, SORT_ORDER_ASC, str(self.user_id), team_id=None)
mock_list.assert_called_once_with(
1, 20, SORT_FIELD_ASSIGNEE, SORT_ORDER_ASC, str(self.user_id), team_id=None, status_filter=None
)

@patch("todo.repositories.task_repository.TaskRepository.count")
@patch("todo.repositories.task_repository.TaskRepository.list")
Expand All @@ -72,7 +78,9 @@ def test_field_specific_defaults_integration(self, mock_list, mock_count):
response = self.client.get("/v1/tasks", {"sort_by": sort_field})

self.assertEqual(response.status_code, status.HTTP_200_OK)
mock_list.assert_called_with(1, 20, sort_field, expected_order, str(self.user_id), team_id=None)
mock_list.assert_called_with(
1, 20, sort_field, expected_order, str(self.user_id), team_id=None, status_filter=None
)

@patch("todo.repositories.task_repository.TaskRepository.count")
@patch("todo.repositories.task_repository.TaskRepository.list")
Expand All @@ -84,7 +92,9 @@ def test_pagination_with_sorting_integration(self, mock_list, mock_count):

self.assertEqual(response.status_code, status.HTTP_200_OK)

mock_list.assert_called_with(3, 5, SORT_FIELD_CREATED_AT, SORT_ORDER_ASC, str(self.user_id), team_id=None)
mock_list.assert_called_with(
3, 5, SORT_FIELD_CREATED_AT, SORT_ORDER_ASC, str(self.user_id), team_id=None, status_filter=None
)

def test_invalid_sort_parameters_integration(self):
response = self.client.get("/v1/tasks", {"sort_by": "invalid_field"})
Expand All @@ -103,7 +113,9 @@ def test_default_behavior_integration(self, mock_list, mock_count):

self.assertEqual(response.status_code, status.HTTP_200_OK)

mock_list.assert_called_with(1, 20, SORT_FIELD_CREATED_AT, SORT_ORDER_DESC, str(self.user_id), team_id=None)
mock_list.assert_called_with(
1, 20, SORT_FIELD_CREATED_AT, SORT_ORDER_DESC, str(self.user_id), team_id=None, status_filter=None
)

@patch("todo.services.task_service.reverse_lazy", return_value="/v1/tasks")
@patch("todo.repositories.task_repository.TaskRepository.count")
Expand Down
16 changes: 14 additions & 2 deletions todo/tests/integration/test_tasks_pagination.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,13 @@ def test_pagination_settings_integration(self, mock_get_tasks):

self.assertEqual(response.status_code, 200)
mock_get_tasks.assert_called_with(
page=1, limit=default_limit, sort_by="createdAt", order="desc", user_id=str(self.user_id), team_id=None
page=1,
limit=default_limit,
sort_by="createdAt",
order="desc",
user_id=str(self.user_id),
team_id=None,
status_filter=None,
)

mock_get_tasks.reset_mock()
Expand All @@ -30,7 +36,13 @@ def test_pagination_settings_integration(self, mock_get_tasks):

self.assertEqual(response.status_code, 200)
mock_get_tasks.assert_called_with(
page=1, limit=10, sort_by="createdAt", order="desc", user_id=str(self.user_id), team_id=None
page=1,
limit=10,
sort_by="createdAt",
order="desc",
user_id=str(self.user_id),
team_id=None,
status_filter=None,
)

# Verify API rejects values above max limit
Expand Down
2 changes: 1 addition & 1 deletion todo/tests/unit/repositories/test_task_repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ def test_count_returns_total_task_count(self):
result = TaskRepository.count()

self.assertEqual(result, 42)
self.mock_collection.count_documents.assert_called_once_with({})
self.mock_collection.count_documents.assert_called_once_with({"status": {"$ne": "DONE"}})

def test_get_all_returns_all_tasks(self):
self.mock_collection.find.return_value = self.task_data
Expand Down
30 changes: 22 additions & 8 deletions todo/tests/unit/services/test_task_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,9 @@ def test_get_tasks_returns_paginated_response(
response.links.prev, f"{self.mock_reverse_lazy('tasks')}?page=1&limit=1&sort_by=createdAt&order=desc"
)

mock_list.assert_called_once_with(2, 1, "createdAt", "desc", str(self.user_id), team_id=None)
mock_list.assert_called_once_with(
2, 1, "createdAt", "desc", str(self.user_id), team_id=None, status_filter=None
)
mock_count.assert_called_once()

@patch("todo.services.task_service.UserRepository.get_by_id")
Expand Down Expand Up @@ -111,7 +113,7 @@ def test_get_tasks_returns_empty_response_if_no_tasks_present(self, mock_list: M
self.assertEqual(len(response.tasks), 0)
self.assertIsNone(response.links)

mock_list.assert_called_once_with(1, 10, "createdAt", "desc", "test_user", team_id=None)
mock_list.assert_called_once_with(1, 10, "createdAt", "desc", "test_user", team_id=None, status_filter=None)
mock_count.assert_called_once()

@patch("todo.services.task_service.TaskRepository.count")
Expand Down Expand Up @@ -294,7 +296,9 @@ def test_get_tasks_default_sorting(self, mock_list, mock_count):

TaskService.get_tasks(page=1, limit=20, sort_by="createdAt", order="desc", user_id="test_user")

mock_list.assert_called_once_with(1, 20, SORT_FIELD_CREATED_AT, SORT_ORDER_DESC, "test_user", team_id=None)
mock_list.assert_called_once_with(
1, 20, SORT_FIELD_CREATED_AT, SORT_ORDER_DESC, "test_user", team_id=None, status_filter=None
)

@patch("todo.services.task_service.TaskRepository.count")
@patch("todo.services.task_service.TaskRepository.list")
Expand All @@ -304,7 +308,9 @@ def test_get_tasks_explicit_sort_by_priority(self, mock_list, mock_count):

TaskService.get_tasks(page=1, limit=20, sort_by=SORT_FIELD_PRIORITY, order=SORT_ORDER_DESC, user_id="test_user")

mock_list.assert_called_once_with(1, 20, SORT_FIELD_PRIORITY, SORT_ORDER_DESC, "test_user", team_id=None)
mock_list.assert_called_once_with(
1, 20, SORT_FIELD_PRIORITY, SORT_ORDER_DESC, "test_user", team_id=None, status_filter=None
)

@patch("todo.services.task_service.TaskRepository.count")
@patch("todo.services.task_service.TaskRepository.list")
Expand All @@ -314,7 +320,9 @@ def test_get_tasks_sort_by_due_at_default_order(self, mock_list, mock_count):

TaskService.get_tasks(page=1, limit=20, sort_by=SORT_FIELD_DUE_AT, order="asc", user_id="test_user")

mock_list.assert_called_once_with(1, 20, SORT_FIELD_DUE_AT, SORT_ORDER_ASC, "test_user", team_id=None)
mock_list.assert_called_once_with(
1, 20, SORT_FIELD_DUE_AT, SORT_ORDER_ASC, "test_user", team_id=None, status_filter=None
)

@patch("todo.services.task_service.TaskRepository.count")
@patch("todo.services.task_service.TaskRepository.list")
Expand All @@ -324,7 +332,9 @@ def test_get_tasks_sort_by_priority_default_order(self, mock_list, mock_count):

TaskService.get_tasks(page=1, limit=20, sort_by=SORT_FIELD_PRIORITY, order="desc", user_id="test_user")

mock_list.assert_called_once_with(1, 20, SORT_FIELD_PRIORITY, SORT_ORDER_DESC, "test_user", team_id=None)
mock_list.assert_called_once_with(
1, 20, SORT_FIELD_PRIORITY, SORT_ORDER_DESC, "test_user", team_id=None, status_filter=None
)

@patch("todo.services.task_service.TaskRepository.count")
@patch("todo.services.task_service.TaskRepository.list")
Expand All @@ -334,7 +344,9 @@ def test_get_tasks_sort_by_assignee_default_order(self, mock_list, mock_count):

TaskService.get_tasks(page=1, limit=20, sort_by=SORT_FIELD_ASSIGNEE, order="asc", user_id="test_user")

mock_list.assert_called_once_with(1, 20, SORT_FIELD_ASSIGNEE, SORT_ORDER_ASC, "test_user", team_id=None)
mock_list.assert_called_once_with(
1, 20, SORT_FIELD_ASSIGNEE, SORT_ORDER_ASC, "test_user", team_id=None, status_filter=None
)

@patch("todo.services.task_service.TaskRepository.count")
@patch("todo.services.task_service.TaskRepository.list")
Expand All @@ -344,7 +356,9 @@ def test_get_tasks_sort_by_created_at_default_order(self, mock_list, mock_count)

TaskService.get_tasks(page=1, limit=20, sort_by=SORT_FIELD_CREATED_AT, order="desc", user_id="test_user")

mock_list.assert_called_once_with(1, 20, SORT_FIELD_CREATED_AT, SORT_ORDER_DESC, "test_user", team_id=None)
mock_list.assert_called_once_with(
1, 20, SORT_FIELD_CREATED_AT, SORT_ORDER_DESC, "test_user", team_id=None, status_filter=None
)

@patch("todo.services.task_service.reverse_lazy", return_value="/v1/tasks")
def test_build_page_url_includes_sort_parameters(self, mock_reverse):
Expand Down
Loading