Skip to content

Commit 74852ba

Browse files
author
Jeff Yanta
committed
Fix unread count to not count messages sent by the reader
1 parent 26bf2b7 commit 74852ba

File tree

4 files changed

+18
-7
lines changed

4 files changed

+18
-7
lines changed

pkg/code/data/chat/v2/memory/store.go

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,12 +98,13 @@ func (s *store) GetAllMembersByPlatformId(_ context.Context, platform chat.Platf
9898
}
9999

100100
// GetUnreadCount implements chat.store.GetUnreadCount
101-
func (s *store) GetUnreadCount(_ context.Context, chatId chat.ChatId, readPointer chat.MessageId) (uint32, error) {
101+
func (s *store) GetUnreadCount(_ context.Context, chatId chat.ChatId, memberId chat.MemberId, readPointer chat.MessageId) (uint32, error) {
102102
s.mu.Lock()
103103
defer s.mu.Unlock()
104104

105105
items := s.findMessagesByChatId(chatId)
106106
items = s.filterMessagesAfter(items, readPointer)
107+
items = s.filterMessagesNotSentBy(items, memberId)
107108
items = s.filterNotifiedMessages(items)
108109
return uint32(len(items)), nil
109110
}
@@ -414,6 +415,16 @@ func (s *store) filterMessagesAfter(items []*chat.MessageRecord, pointer chat.Me
414415
return res
415416
}
416417

418+
func (s *store) filterMessagesNotSentBy(items []*chat.MessageRecord, sender chat.MemberId) []*chat.MessageRecord {
419+
var res []*chat.MessageRecord
420+
for _, item := range items {
421+
if item.Sender == nil || !bytes.Equal(item.Sender[:], sender[:]) {
422+
res = append(res, item)
423+
}
424+
}
425+
return res
426+
}
427+
417428
func (s *store) filterNotifiedMessages(items []*chat.MessageRecord) []*chat.MessageRecord {
418429
var res []*chat.MessageRecord
419430
for _, item := range items {

pkg/code/data/chat/v2/store.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,8 +42,8 @@ type Store interface {
4242
// Note: Cursor is a message ID
4343
GetAllMessagesByChatId(ctx context.Context, chatId ChatId, cursor query.Cursor, direction query.Ordering, limit uint64) ([]*MessageRecord, error)
4444

45-
// GetUnreadCount gets the unread message count for a chat ID at a read pointer
46-
GetUnreadCount(ctx context.Context, chatId ChatId, readPointer MessageId) (uint32, error)
45+
// GetUnreadCount gets the unread message count for a chat ID at a read pointer for a given chat member
46+
GetUnreadCount(ctx context.Context, chatId ChatId, memberId MemberId, readPointer MessageId) (uint32, error)
4747

4848
// PutChat creates a new chat
4949
PutChat(ctx context.Context, record *ChatRecord) error

pkg/code/data/internal.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -402,7 +402,7 @@ type DatabaseData interface {
402402
GetAllChatMembersV2(ctx context.Context, chatId chat_v2.ChatId) ([]*chat_v2.MemberRecord, error)
403403
GetPlatformUserChatMembershipV2(ctx context.Context, platform chat_v2.Platform, platformId string, opts ...query.Option) ([]*chat_v2.MemberRecord, error)
404404
GetAllChatMessagesV2(ctx context.Context, chatId chat_v2.ChatId, opts ...query.Option) ([]*chat_v2.MessageRecord, error)
405-
GetChatUnreadCountV2(ctx context.Context, chatId chat_v2.ChatId, readPointer chat_v2.MessageId) (uint32, error)
405+
GetChatUnreadCountV2(ctx context.Context, chatId chat_v2.ChatId, memberId chat_v2.MemberId, readPointer chat_v2.MessageId) (uint32, error)
406406
PutChatV2(ctx context.Context, record *chat_v2.ChatRecord) error
407407
PutChatMemberV2(ctx context.Context, record *chat_v2.MemberRecord) error
408408
PutChatMessageV2(ctx context.Context, record *chat_v2.MessageRecord) error
@@ -1492,8 +1492,8 @@ func (dp *DatabaseProvider) GetAllChatMessagesV2(ctx context.Context, chatId cha
14921492
}
14931493
return dp.chatv2.GetAllMessagesByChatId(ctx, chatId, req.Cursor, req.SortBy, req.Limit)
14941494
}
1495-
func (dp *DatabaseProvider) GetChatUnreadCountV2(ctx context.Context, chatId chat_v2.ChatId, readPointer chat_v2.MessageId) (uint32, error) {
1496-
return dp.chatv2.GetUnreadCount(ctx, chatId, readPointer)
1495+
func (dp *DatabaseProvider) GetChatUnreadCountV2(ctx context.Context, chatId chat_v2.ChatId, memberId chat_v2.MemberId, readPointer chat_v2.MessageId) (uint32, error) {
1496+
return dp.chatv2.GetUnreadCount(ctx, chatId, memberId, readPointer)
14971497
}
14981498
func (dp *DatabaseProvider) PutChatV2(ctx context.Context, record *chat_v2.ChatRecord) error {
14991499
return dp.chatv2.PutChat(ctx, record)

pkg/code/server/grpc/chat/v2/server.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -255,7 +255,7 @@ func (s *server) GetChats(ctx context.Context, req *chatpb.GetChatsRequest) (*ch
255255
if platformUserMemberRecord.ReadPointer != nil {
256256
readPointer = *platformUserMemberRecord.ReadPointer
257257
}
258-
unreadCount, err := s.data.GetChatUnreadCountV2(ctx, chatRecord.ChatId, readPointer)
258+
unreadCount, err := s.data.GetChatUnreadCountV2(ctx, chatRecord.ChatId, platformUserMemberRecord.MemberId, readPointer)
259259
if err != nil {
260260
log.WithError(err).Warn("failure getting unread count")
261261
return nil, status.Error(codes.Internal, "")

0 commit comments

Comments
 (0)