diff --git a/backend/apps/chat/managers.py b/backend/apps/chat/managers.py index e69de29b..5fc424ed 100644 --- a/backend/apps/chat/managers.py +++ b/backend/apps/chat/managers.py @@ -0,0 +1,24 @@ +from django.db import models +from django.utils.timezone import now + +# https://chatgpt.com/share/67fda93c-b758-8005-b5d1-56978a0bda6e +class ChatManager(models.Manager): + def get_or_create_chat(self, user1, user2): + """Ensure consistent ordering of users in the chat model.""" + if user1.id > user2.id: + user1, user2 = user2, user1 + chat, created = self.get_or_create(user1=user1, user2=user2) + return chat, created + + +class MessageManager(models.Manager): + def create_message(self, chat, sender, content): + """Create a new message and update the chat's updatedAt field.""" + # Create the message + message = self.create(chat=chat, sender=sender, content=content) + + # Update the chat's updatedAt field + chat.updatedAt = now() + chat.save() + + return message diff --git a/backend/apps/chat/models.py b/backend/apps/chat/models.py index 98a4a1ce..b9d1fbf7 100644 --- a/backend/apps/chat/models.py +++ b/backend/apps/chat/models.py @@ -1,7 +1,7 @@ from django.db import models -from django.db import models from django.contrib.auth.models import User from django.utils.timezone import now +from .managers import ChatManager, MessageManager class Chat(models.Model): user1 = models.ForeignKey(User, on_delete=models.CASCADE, related_name="chats_initiated") @@ -9,6 +9,8 @@ class Chat(models.Model): createdAt = models.DateTimeField(auto_now_add=True) updatedAt = models.DateTimeField(auto_now=True) + objects = ChatManager() + class Meta: unique_together = ('user1', 'user2') ordering = ['-updatedAt'] @@ -16,29 +18,17 @@ class Meta: def __str__(self): return f"Chat between {self.user1} and {self.user2}" - @classmethod - def get_or_create_chat(cls, user1, user2): - """Ensure consistent ordering of users in the chat model.""" - if user1.id > user2.id: - user1, user2 = user2, user1 - chat, created = cls.objects.get_or_create(user1=user1, user2=user2) - return chat - class Message(models.Model): chat = models.ForeignKey(Chat, on_delete=models.CASCADE, related_name="messages") sender = models.ForeignKey(User, on_delete=models.CASCADE, related_name="sent_messages") content = models.TextField() timestamp = models.DateTimeField(default=now) + objects = MessageManager() + class Meta: ordering = ['timestamp'] def __str__(self): return f"From {self.sender} in Chat {self.chat.id} at {self.timestamp}" - def save(self, *args, **kwargs): - """Update the chat's updatedAt field whenever a new message is saved.""" - self.chat.updatedAt = now() - self.chat.save() - super().save(*args, **kwargs) - diff --git a/backend/apps/chat/serializers.py b/backend/apps/chat/serializers.py index e69de29b..13209af2 100644 --- a/backend/apps/chat/serializers.py +++ b/backend/apps/chat/serializers.py @@ -0,0 +1,26 @@ +from rest_framework import serializers +from django.contrib.auth.models import User +from .models import Chat, Message + +# https://chatgpt.com/share/67fd9c70-d378-8005-8c39-b0453f0f790f +class UserSerializer(serializers.ModelSerializer): + class Meta: + model = User + fields = ['id', 'username'] + +class MessageSerializer(serializers.ModelSerializer): + sender = UserSerializer(read_only=True) + + class Meta: + model = Message + fields = ['id', 'chat', 'sender', 'content', 'timestamp'] + read_only_fields = ['timestamp'] + +class ChatSerializer(serializers.ModelSerializer): + user1 = UserSerializer(read_only=True) + user2 = UserSerializer(read_only=True) + messages = MessageSerializer(many=True, read_only=True) + + class Meta: + model = Chat + fields = ['id', 'user1', 'user2', 'createdAt', 'updatedAt', 'messages'] diff --git a/backend/apps/chat/tests.py b/backend/apps/chat/tests.py index 7ce503c2..0dbbe216 100644 --- a/backend/apps/chat/tests.py +++ b/backend/apps/chat/tests.py @@ -1,3 +1,78 @@ -from django.test import TestCase +from django.contrib.auth.models import User +from django.urls import reverse +from rest_framework import status +from rest_framework.test import APITestCase, APIClient +from .models import Chat, Message -# Create your tests here. +class ChatAPITestCase(APITestCase): + + def setUp(self): + self.user1 = User.objects.create_user(username='alice', password='pass1234') + self.user2 = User.objects.create_user(username='bob', password='pass1234') + self.user3 = User.objects.create_user(username='charlie', password='pass1234') + + self.client = APIClient() + # Obtain JWT tokens + self.login_url = "/users/login/" + response = self.client.post(self.login_url, {"username": "alice", "password": "pass1234"}) + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.refresh_token = response.json().get("refreshToken") + self.access_token = response.json().get("accessToken") + self.client.credentials(HTTP_AUTHORIZATION=f"Bearer {self.access_token}") + + def test_create_chat(self): + self.assertEqual(Chat.objects.count(), 0) + url = reverse('chat-list') + data = {'user2': self.user2.id} + response = self.client.post(url, data) + self.assertEqual(response.status_code, status.HTTP_201_CREATED) + self.assertEqual(Chat.objects.count(), 1) + + def test_prevent_duplicate_chat(self): + Chat.objects.get_or_create_chat(self.user1, self.user2) + url = reverse('chat-list') + data = {'user2': self.user2.id} + response = self.client.post(url, data) + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + + def test_list_user_chats(self): + Chat.objects.get_or_create_chat(self.user1, self.user2) + Chat.objects.get_or_create_chat(self.user1, self.user3) + url = reverse('chat-list') + response = self.client.get(url) + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(len(response.data), 2) + # ensure that the receivers match + response_receiver_ids = {chat['user2']['id'] for chat in response.data} + expected_receiver_ids = {self.user2.id, self.user3.id} + self.assertSetEqual(response_receiver_ids, expected_receiver_ids) + + def test_send_message(self): + chat, _ = Chat.objects.get_or_create_chat(self.user1, self.user2) + url = reverse('message-list') + data = {'chat': chat.id, 'content': 'Hello Bob!'} + response = self.client.post(url, data) + self.assertEqual(response.status_code, status.HTTP_201_CREATED) + self.assertEqual(Message.objects.count(), 1) + message = Message.objects.first() + self.assertEqual(message.content, 'Hello Bob!') + self.assertEqual(message.sender, self.user1) + + + def test_list_chat_messages(self): + chat, created = Chat.objects.get_or_create_chat(self.user1, self.user2) + Message.objects.create(chat=chat, sender=self.user1, content='Hi Bob!') + Message.objects.create(chat=chat, sender=self.user2, content='Hey Alice!') + + url = reverse('chat-messages', kwargs={'pk': chat.id}) + response = self.client.get(url) + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(len(response.data), 2) + self.assertEqual(response.data[0]['content'], 'Hi Bob!') + self.assertEqual(response.data[1]['content'], 'Hey Alice!') + + def test_unauthenticated_access(self): + self.client.logout() + url = reverse('chat-list') + response = self.client.get(url) + self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED) diff --git a/backend/apps/chat/urls.py b/backend/apps/chat/urls.py index b3ba8702..341177f0 100644 --- a/backend/apps/chat/urls.py +++ b/backend/apps/chat/urls.py @@ -1,11 +1,19 @@ # urls.py -from django.urls import path -from .views import select_send_message, send_message, chat_list, chat_detail +from django.urls import path, include +# from .views import select_send_message, send_message, chat_list, chat_detail +# https://chatgpt.com/share/67fd9c70-d378-8005-8c39-b0453f0f790f +from rest_framework.routers import DefaultRouter +from .views import ChatViewSet, MessageViewSet + +router = DefaultRouter() +router.register(r'chats', ChatViewSet, basename='chat') +router.register(r'messages', MessageViewSet, basename='message') urlpatterns = [ - path('', chat_list, name='chat_list'), - path('/', chat_detail, name='chat_detail'), - path('send/', select_send_message, name='select_send_message'), - path('send//', send_message, name='send_message'), + path('', include(router.urls)), + # path('legacy/', chat_list, name='chat_list'), + # path('/', chat_detail, name='chat_detail'), + # path('send/', select_send_message, name='select_send_message'), + # path('send//', send_message, name='send_message'), ] diff --git a/backend/apps/chat/views.py b/backend/apps/chat/views.py index e67eed59..5ca28b9a 100644 --- a/backend/apps/chat/views.py +++ b/backend/apps/chat/views.py @@ -4,51 +4,108 @@ from .models import Chat, User from .forms import MessageForm -@login_required -def select_send_message(request): - """Allow the user to select a recipient for messaging.""" - if request.method == "POST": - user_id = request.POST.get("receiver") - return redirect(reverse('send_message', args=[user_id])) # Redirect to send_message view - - users = User.objects.exclude(id=request.user.id) # Exclude the logged-in user - return render(request, 'select_send_message.html', {'users': users}) - -@login_required -def send_message(request, user_id): - receiver = get_object_or_404(User, id=user_id) - chat = Chat.get_or_create_chat(request.user, receiver) - - if request.method == "POST": - form = MessageForm(request.POST) - if form.is_valid(): - message = form.save(commit=False) - message.sender = request.user - message.chat = chat - message.save() - return redirect('chat_detail', chat_id=chat.id) - # Handle GET request by preloading the message form - else: - form = MessageForm() - - return render(request, 'send_message.html', {'form': form, 'receiver': receiver}) - -@login_required -def chat_list(request): - chats = Chat.objects.filter(user1=request.user) | Chat.objects.filter(user2=request.user) - return render(request, 'chat_list.html', {'chats': chats}) - -@login_required -def chat_detail(request, chat_id): - chat = get_object_or_404(Chat, id=chat_id) +# https://chatgpt.com/share/67fd9c70-d378-8005-8c39-b0453f0f790f +from rest_framework import viewsets, permissions, serializers +from rest_framework.response import Response +from rest_framework.decorators import action +from django.contrib.auth.models import User +from django.db.models import Q +from .models import Chat, Message +from .serializers import ChatSerializer, MessageSerializer + +class ChatViewSet(viewsets.ModelViewSet): + queryset = Chat.objects.all() + serializer_class = ChatSerializer + permission_classes = [permissions.IsAuthenticated] + + def get_queryset(self): + """Restrict to chats where the user is a participant.""" + user = self.request.user + return Chat.objects.filter(Q(user1=user) | Q(user2=user)).order_by('-updatedAt') + # return Chat.objects.filter(user1=user).union(Chat.objects.filter(user2=user)).order_by('-updatedAt') + + def perform_create(self, serializer): + """Prevent duplicate chats, enforce user1-user2 ordering.""" + user1 = self.request.user + user2_id = self.request.data.get('user2') + if not user2_id: + raise serializers.ValidationError({"user2": "This field is required."}) + + try: + user2 = User.objects.get(id=user2_id) + except User.DoesNotExist: + raise serializers.ValidationError({"user2": "User not found."}) + + chat, created = Chat.objects.get_or_create_chat(user1, user2) + if not created: + raise serializers.ValidationError("Chat already exists.") + serializer.instance = chat + + @action(detail=True, methods=['get']) + def messages(self, request, pk=None): + """List messages for a chat.""" + chat = self.get_object() + messages = chat.messages.all() + serializer = MessageSerializer(messages, many=True) + return Response(serializer.data) + +class MessageViewSet(viewsets.ModelViewSet): + queryset = Message.objects.all() + serializer_class = MessageSerializer + permission_classes = [permissions.IsAuthenticated] + + def perform_create(self, serializer): + serializer.save(sender=self.request.user) + +# ============================================= +# NOTE: ENPOINTS BELOW ARE CONSIDERED LEGACY +# ============================================= + +# @login_required +# def select_send_message(request): +# """Allow the user to select a recipient for messaging.""" +# if request.method == "POST": +# user_id = request.POST.get("receiver") +# return redirect(reverse('send_message', args=[user_id])) # Redirect to send_message view + +# users = User.objects.exclude(id=request.user.id) # Exclude the logged-in user +# return render(request, 'select_send_message.html', {'users': users}) + +# @login_required +# def send_message(request, user_id): +# receiver = get_object_or_404(User, id=user_id) +# chat = Chat.get_or_create_chat(request.user, receiver) + +# if request.method == "POST": +# form = MessageForm(request.POST) +# if form.is_valid(): +# message = form.save(commit=False) +# message.sender = request.user +# message.chat = chat +# message.save() +# return redirect('chat_detail', chat_id=chat.id) +# # Handle GET request by preloading the message form +# else: +# form = MessageForm() + +# return render(request, 'send_message.html', {'form': form, 'receiver': receiver}) + +# @login_required +# def chat_list(request): +# chats = Chat.objects.filter(user1=request.user) | Chat.objects.filter(user2=request.user) +# return render(request, 'chat_list.html', {'chats': chats}) + +# @login_required +# def chat_detail(request, chat_id): +# chat = get_object_or_404(Chat, id=chat_id) - # Ensure the logged-in user is part of this chat - if request.user != chat.user1 and request.user != chat.user2: - return redirect('chat_list') +# # Ensure the logged-in user is part of this chat +# if request.user != chat.user1 and request.user != chat.user2: +# return redirect('chat_list') - messages = chat.messages.all() - form = MessageForm() +# messages = chat.messages.all() +# form = MessageForm() - return render(request, 'chat_detail.html', {'chat': chat, 'messages': messages, 'form': form}) +# return render(request, 'chat_detail.html', {'chat': chat, 'messages': messages, 'form': form})