Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions backend/apps/chat/managers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
from django.db import models
from django.utils.timezone import now

# https://chatgpt.com/share/67fda93c-b758-8005-b5d1-56978a0bda6e
class ChatManager(models.Manager):
def get_or_create_chat(self, user1, user2):
"""Ensure consistent ordering of users in the chat model."""
if user1.id > user2.id:
user1, user2 = user2, user1
chat, created = self.get_or_create(user1=user1, user2=user2)
return chat, created


class MessageManager(models.Manager):
def create_message(self, chat, sender, content):
"""Create a new message and update the chat's updatedAt field."""
# Create the message
message = self.create(chat=chat, sender=sender, content=content)

# Update the chat's updatedAt field
chat.updatedAt = now()
chat.save()

return message
20 changes: 5 additions & 15 deletions backend/apps/chat/models.py
Original file line number Diff line number Diff line change
@@ -1,44 +1,34 @@
from django.db import models
from django.db import models
from django.contrib.auth.models import User
from django.utils.timezone import now
from .managers import ChatManager, MessageManager

class Chat(models.Model):
user1 = models.ForeignKey(User, on_delete=models.CASCADE, related_name="chats_initiated")
user2 = models.ForeignKey(User, on_delete=models.CASCADE, related_name="chats_received")
createdAt = models.DateTimeField(auto_now_add=True)
updatedAt = models.DateTimeField(auto_now=True)

objects = ChatManager()

class Meta:
unique_together = ('user1', 'user2')
ordering = ['-updatedAt']

def __str__(self):
return f"Chat between {self.user1} and {self.user2}"

@classmethod
def get_or_create_chat(cls, user1, user2):
"""Ensure consistent ordering of users in the chat model."""
if user1.id > user2.id:
user1, user2 = user2, user1
chat, created = cls.objects.get_or_create(user1=user1, user2=user2)
return chat

class Message(models.Model):
chat = models.ForeignKey(Chat, on_delete=models.CASCADE, related_name="messages")
sender = models.ForeignKey(User, on_delete=models.CASCADE, related_name="sent_messages")
content = models.TextField()
timestamp = models.DateTimeField(default=now)

objects = MessageManager()

class Meta:
ordering = ['timestamp']

def __str__(self):
return f"From {self.sender} in Chat {self.chat.id} at {self.timestamp}"

def save(self, *args, **kwargs):
"""Update the chat's updatedAt field whenever a new message is saved."""
self.chat.updatedAt = now()
self.chat.save()
super().save(*args, **kwargs)

26 changes: 26 additions & 0 deletions backend/apps/chat/serializers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
from rest_framework import serializers
from django.contrib.auth.models import User
from .models import Chat, Message

# https://chatgpt.com/share/67fd9c70-d378-8005-8c39-b0453f0f790f
class UserSerializer(serializers.ModelSerializer):
class Meta:
model = User
fields = ['id', 'username']

class MessageSerializer(serializers.ModelSerializer):
sender = UserSerializer(read_only=True)

class Meta:
model = Message
fields = ['id', 'chat', 'sender', 'content', 'timestamp']
read_only_fields = ['timestamp']

class ChatSerializer(serializers.ModelSerializer):
user1 = UserSerializer(read_only=True)
user2 = UserSerializer(read_only=True)
messages = MessageSerializer(many=True, read_only=True)

class Meta:
model = Chat
fields = ['id', 'user1', 'user2', 'createdAt', 'updatedAt', 'messages']
79 changes: 77 additions & 2 deletions backend/apps/chat/tests.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,78 @@
from django.test import TestCase
from django.contrib.auth.models import User
from django.urls import reverse
from rest_framework import status
from rest_framework.test import APITestCase, APIClient
from .models import Chat, Message

# Create your tests here.
class ChatAPITestCase(APITestCase):

def setUp(self):
self.user1 = User.objects.create_user(username='alice', password='pass1234')
self.user2 = User.objects.create_user(username='bob', password='pass1234')
self.user3 = User.objects.create_user(username='charlie', password='pass1234')

self.client = APIClient()
# Obtain JWT tokens
self.login_url = "/users/login/"
response = self.client.post(self.login_url, {"username": "alice", "password": "pass1234"})
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.refresh_token = response.json().get("refreshToken")
self.access_token = response.json().get("accessToken")
self.client.credentials(HTTP_AUTHORIZATION=f"Bearer {self.access_token}")

def test_create_chat(self):
self.assertEqual(Chat.objects.count(), 0)
url = reverse('chat-list')
data = {'user2': self.user2.id}
response = self.client.post(url, data)
self.assertEqual(response.status_code, status.HTTP_201_CREATED)
self.assertEqual(Chat.objects.count(), 1)

def test_prevent_duplicate_chat(self):
Chat.objects.get_or_create_chat(self.user1, self.user2)
url = reverse('chat-list')
data = {'user2': self.user2.id}
response = self.client.post(url, data)
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)

def test_list_user_chats(self):
Chat.objects.get_or_create_chat(self.user1, self.user2)
Chat.objects.get_or_create_chat(self.user1, self.user3)
url = reverse('chat-list')
response = self.client.get(url)
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(len(response.data), 2)
# ensure that the receivers match
response_receiver_ids = {chat['user2']['id'] for chat in response.data}
expected_receiver_ids = {self.user2.id, self.user3.id}
self.assertSetEqual(response_receiver_ids, expected_receiver_ids)

def test_send_message(self):
chat, _ = Chat.objects.get_or_create_chat(self.user1, self.user2)
url = reverse('message-list')
data = {'chat': chat.id, 'content': 'Hello Bob!'}
response = self.client.post(url, data)
self.assertEqual(response.status_code, status.HTTP_201_CREATED)
self.assertEqual(Message.objects.count(), 1)
message = Message.objects.first()
self.assertEqual(message.content, 'Hello Bob!')
self.assertEqual(message.sender, self.user1)


def test_list_chat_messages(self):
chat, created = Chat.objects.get_or_create_chat(self.user1, self.user2)
Message.objects.create(chat=chat, sender=self.user1, content='Hi Bob!')
Message.objects.create(chat=chat, sender=self.user2, content='Hey Alice!')

url = reverse('chat-messages', kwargs={'pk': chat.id})
response = self.client.get(url)
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(len(response.data), 2)
self.assertEqual(response.data[0]['content'], 'Hi Bob!')
self.assertEqual(response.data[1]['content'], 'Hey Alice!')

def test_unauthenticated_access(self):
self.client.logout()
url = reverse('chat-list')
response = self.client.get(url)
self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED)
20 changes: 14 additions & 6 deletions backend/apps/chat/urls.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,19 @@
# urls.py
from django.urls import path
from .views import select_send_message, send_message, chat_list, chat_detail
from django.urls import path, include
# from .views import select_send_message, send_message, chat_list, chat_detail
# https://chatgpt.com/share/67fd9c70-d378-8005-8c39-b0453f0f790f
from rest_framework.routers import DefaultRouter
from .views import ChatViewSet, MessageViewSet

router = DefaultRouter()
router.register(r'chats', ChatViewSet, basename='chat')
router.register(r'messages', MessageViewSet, basename='message')

urlpatterns = [
path('', chat_list, name='chat_list'),
path('<int:chat_id>/', chat_detail, name='chat_detail'),
path('send/', select_send_message, name='select_send_message'),
path('send/<int:user_id>/', send_message, name='send_message'),
path('', include(router.urls)),
# path('legacy/', chat_list, name='chat_list'),
# path('<int:chat_id>/', chat_detail, name='chat_detail'),
# path('send/', select_send_message, name='select_send_message'),
# path('send/<int:user_id>/', send_message, name='send_message'),
]

143 changes: 100 additions & 43 deletions backend/apps/chat/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,51 +4,108 @@
from .models import Chat, User
from .forms import MessageForm

@login_required
def select_send_message(request):
"""Allow the user to select a recipient for messaging."""
if request.method == "POST":
user_id = request.POST.get("receiver")
return redirect(reverse('send_message', args=[user_id])) # Redirect to send_message view

users = User.objects.exclude(id=request.user.id) # Exclude the logged-in user
return render(request, 'select_send_message.html', {'users': users})

@login_required
def send_message(request, user_id):
receiver = get_object_or_404(User, id=user_id)
chat = Chat.get_or_create_chat(request.user, receiver)

if request.method == "POST":
form = MessageForm(request.POST)
if form.is_valid():
message = form.save(commit=False)
message.sender = request.user
message.chat = chat
message.save()
return redirect('chat_detail', chat_id=chat.id)
# Handle GET request by preloading the message form
else:
form = MessageForm()

return render(request, 'send_message.html', {'form': form, 'receiver': receiver})

@login_required
def chat_list(request):
chats = Chat.objects.filter(user1=request.user) | Chat.objects.filter(user2=request.user)
return render(request, 'chat_list.html', {'chats': chats})

@login_required
def chat_detail(request, chat_id):
chat = get_object_or_404(Chat, id=chat_id)
# https://chatgpt.com/share/67fd9c70-d378-8005-8c39-b0453f0f790f
from rest_framework import viewsets, permissions, serializers
from rest_framework.response import Response
from rest_framework.decorators import action
from django.contrib.auth.models import User
from django.db.models import Q
from .models import Chat, Message
from .serializers import ChatSerializer, MessageSerializer

class ChatViewSet(viewsets.ModelViewSet):
queryset = Chat.objects.all()
serializer_class = ChatSerializer
permission_classes = [permissions.IsAuthenticated]

def get_queryset(self):
"""Restrict to chats where the user is a participant."""
user = self.request.user
return Chat.objects.filter(Q(user1=user) | Q(user2=user)).order_by('-updatedAt')
# return Chat.objects.filter(user1=user).union(Chat.objects.filter(user2=user)).order_by('-updatedAt')

def perform_create(self, serializer):
"""Prevent duplicate chats, enforce user1-user2 ordering."""
user1 = self.request.user
user2_id = self.request.data.get('user2')
if not user2_id:
raise serializers.ValidationError({"user2": "This field is required."})

try:
user2 = User.objects.get(id=user2_id)
except User.DoesNotExist:
raise serializers.ValidationError({"user2": "User not found."})

chat, created = Chat.objects.get_or_create_chat(user1, user2)
if not created:
raise serializers.ValidationError("Chat already exists.")
serializer.instance = chat

@action(detail=True, methods=['get'])
def messages(self, request, pk=None):
"""List messages for a chat."""
chat = self.get_object()
messages = chat.messages.all()
serializer = MessageSerializer(messages, many=True)
return Response(serializer.data)

class MessageViewSet(viewsets.ModelViewSet):
queryset = Message.objects.all()
serializer_class = MessageSerializer
permission_classes = [permissions.IsAuthenticated]

def perform_create(self, serializer):
serializer.save(sender=self.request.user)

# =============================================
# NOTE: ENPOINTS BELOW ARE CONSIDERED LEGACY
# =============================================

# @login_required
# def select_send_message(request):
# """Allow the user to select a recipient for messaging."""
# if request.method == "POST":
# user_id = request.POST.get("receiver")
# return redirect(reverse('send_message', args=[user_id])) # Redirect to send_message view

# users = User.objects.exclude(id=request.user.id) # Exclude the logged-in user
# return render(request, 'select_send_message.html', {'users': users})

# @login_required
# def send_message(request, user_id):
# receiver = get_object_or_404(User, id=user_id)
# chat = Chat.get_or_create_chat(request.user, receiver)

# if request.method == "POST":
# form = MessageForm(request.POST)
# if form.is_valid():
# message = form.save(commit=False)
# message.sender = request.user
# message.chat = chat
# message.save()
# return redirect('chat_detail', chat_id=chat.id)
# # Handle GET request by preloading the message form
# else:
# form = MessageForm()

# return render(request, 'send_message.html', {'form': form, 'receiver': receiver})

# @login_required
# def chat_list(request):
# chats = Chat.objects.filter(user1=request.user) | Chat.objects.filter(user2=request.user)
# return render(request, 'chat_list.html', {'chats': chats})

# @login_required
# def chat_detail(request, chat_id):
# chat = get_object_or_404(Chat, id=chat_id)

# Ensure the logged-in user is part of this chat
if request.user != chat.user1 and request.user != chat.user2:
return redirect('chat_list')
# # Ensure the logged-in user is part of this chat
# if request.user != chat.user1 and request.user != chat.user2:
# return redirect('chat_list')

messages = chat.messages.all()
form = MessageForm()
# messages = chat.messages.all()
# form = MessageForm()

return render(request, 'chat_detail.html', {'chat': chat, 'messages': messages, 'form': form})
# return render(request, 'chat_detail.html', {'chat': chat, 'messages': messages, 'form': form})