Skip to content

Commit f2f99d6

Browse files
author
Jeff Yanta
committed
Flush pointers on chat event stream open
1 parent a2cda00 commit f2f99d6

File tree

5 files changed

+101
-7
lines changed

5 files changed

+101
-7
lines changed

pkg/code/data/chat/v2/memory/store.go

Lines changed: 31 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -71,8 +71,17 @@ func (s *store) GetMessageById(_ context.Context, chatId chat.ChatId, messageId
7171
return &cloned, nil
7272
}
7373

74-
// GetAllMessagesByChat implements chat.Store.GetAllMessagesByChat
75-
func (s *store) GetAllMessagesByChat(_ context.Context, chatId chat.ChatId, cursor query.Cursor, direction query.Ordering, limit uint64) ([]*chat.MessageRecord, error) {
74+
// GetAllMembersByChatId implements chat.Store.GetAllMembersByChatId
75+
func (s *store) GetAllMembersByChatId(_ context.Context, chatId chat.ChatId) ([]*chat.MemberRecord, error) {
76+
items := s.findMembersByChatId(chatId)
77+
if len(items) == 0 {
78+
return nil, chat.ErrMemberNotFound
79+
}
80+
return cloneMemberRecords(items), nil
81+
}
82+
83+
// GetAllMessagesByChatId implements chat.Store.GetAllMessagesByChatId
84+
func (s *store) GetAllMessagesByChatId(_ context.Context, chatId chat.ChatId, cursor query.Cursor, direction query.Ordering, limit uint64) ([]*chat.MessageRecord, error) {
7685
s.mu.Lock()
7786
defer s.mu.Unlock()
7887

@@ -81,10 +90,10 @@ func (s *store) GetAllMessagesByChat(_ context.Context, chatId chat.ChatId, curs
8190
if err != nil {
8291
return nil, err
8392
}
93+
8494
if len(items) == 0 {
8595
return nil, chat.ErrMessageNotFound
8696
}
87-
8897
return cloneMessageRecords(items), nil
8998
}
9099

@@ -275,6 +284,16 @@ func (s *store) findMemberById(chatId chat.ChatId, memberId chat.MemberId) *chat
275284
return nil
276285
}
277286

287+
func (s *store) findMembersByChatId(chatId chat.ChatId) []*chat.MemberRecord {
288+
var res []*chat.MemberRecord
289+
for _, item := range s.memberRecords {
290+
if bytes.Equal(chatId[:], item.ChatId[:]) {
291+
res = append(res, item)
292+
}
293+
}
294+
return res
295+
}
296+
278297
func (s *store) findMessage(data *chat.MessageRecord) *chat.MessageRecord {
279298
for _, item := range s.messageRecords {
280299
if data.Id == item.Id {
@@ -362,6 +381,15 @@ func (s *store) reset() {
362381
s.lastMessageId = 0
363382
}
364383

384+
func cloneMemberRecords(items []*chat.MemberRecord) []*chat.MemberRecord {
385+
res := make([]*chat.MemberRecord, len(items))
386+
for i, item := range items {
387+
cloned := item.Clone()
388+
res[i] = &cloned
389+
}
390+
return res
391+
}
392+
365393
func cloneMessageRecords(items []*chat.MessageRecord) []*chat.MessageRecord {
366394
res := make([]*chat.MessageRecord, len(items))
367395
for i, item := range items {

pkg/code/data/chat/v2/model.go

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,20 @@ func GetPointerTypeFromProto(proto chatpb.Pointer_Kind) PointerType {
119119
}
120120
}
121121

122+
// ToProto returns the proto representation of the pointer type
123+
func (p PointerType) ToProto() chatpb.Pointer_Kind {
124+
switch p {
125+
case PointerTypeSent:
126+
return chatpb.Pointer_SENT
127+
case PointerTypeDelivered:
128+
return chatpb.Pointer_DELIVERED
129+
case PointerTypeRead:
130+
return chatpb.Pointer_READ
131+
default:
132+
return chatpb.Pointer_UNKNOWN
133+
}
134+
}
135+
122136
// String returns the string representation of the pointer type
123137
func (p PointerType) String() string {
124138
switch p {

pkg/code/data/chat/v2/store.go

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,10 +28,15 @@ type Store interface {
2828
// GetMessageById gets a chat message by the chat and message IDs
2929
GetMessageById(ctx context.Context, chatId ChatId, messageId MessageId) (*MessageRecord, error)
3030

31-
// GetAllMessagesByChat gets all messages for a given chat
31+
// GetAllMembersByChatId gets all members for a given chat
32+
//
33+
// todo: Add paging when we introduce group chats
34+
GetAllMembersByChatId(ctx context.Context, chatId ChatId) ([]*MemberRecord, error)
35+
36+
// GetAllMessagesByChatId gets all messages for a given chat
3237
//
3338
// Note: Cursor is a message ID
34-
GetAllMessagesByChat(ctx context.Context, chatId ChatId, cursor query.Cursor, direction query.Ordering, limit uint64) ([]*MessageRecord, error)
39+
GetAllMessagesByChatId(ctx context.Context, chatId ChatId, cursor query.Cursor, direction query.Ordering, limit uint64) ([]*MessageRecord, error)
3540

3641
// PutChat creates a new chat
3742
PutChat(ctx context.Context, record *ChatRecord) error

pkg/code/data/internal.go

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -399,6 +399,7 @@ type DatabaseData interface {
399399
GetChatByIdV2(ctx context.Context, chatId chat_v2.ChatId) (*chat_v2.ChatRecord, error)
400400
GetChatMemberByIdV2(ctx context.Context, chatId chat_v2.ChatId, memberId chat_v2.MemberId) (*chat_v2.MemberRecord, error)
401401
GetChatMessageByIdV2(ctx context.Context, chatId chat_v2.ChatId, messageId chat_v2.MessageId) (*chat_v2.MessageRecord, error)
402+
GetAllChatMembersV2(ctx context.Context, chatId chat_v2.ChatId) ([]*chat_v2.MemberRecord, error)
402403
GetAllChatMessagesV2(ctx context.Context, chatId chat_v2.ChatId, opts ...query.Option) ([]*chat_v2.MessageRecord, error)
403404
PutChatV2(ctx context.Context, record *chat_v2.ChatRecord) error
404405
PutChatMemberV2(ctx context.Context, record *chat_v2.MemberRecord) error
@@ -1472,12 +1473,15 @@ func (dp *DatabaseProvider) GetChatMemberByIdV2(ctx context.Context, chatId chat
14721473
func (dp *DatabaseProvider) GetChatMessageByIdV2(ctx context.Context, chatId chat_v2.ChatId, messageId chat_v2.MessageId) (*chat_v2.MessageRecord, error) {
14731474
return dp.chatv2.GetMessageById(ctx, chatId, messageId)
14741475
}
1476+
func (dp *DatabaseProvider) GetAllChatMembersV2(ctx context.Context, chatId chat_v2.ChatId) ([]*chat_v2.MemberRecord, error) {
1477+
return dp.chatv2.GetAllMembersByChatId(ctx, chatId)
1478+
}
14751479
func (dp *DatabaseProvider) GetAllChatMessagesV2(ctx context.Context, chatId chat_v2.ChatId, opts ...query.Option) ([]*chat_v2.MessageRecord, error) {
14761480
req, err := query.DefaultPaginationHandler(opts...)
14771481
if err != nil {
14781482
return nil, err
14791483
}
1480-
return dp.chatv2.GetAllMessagesByChat(ctx, chatId, req.Cursor, req.SortBy, req.Limit)
1484+
return dp.chatv2.GetAllMessagesByChatId(ctx, chatId, req.Cursor, req.SortBy, req.Limit)
14811485
}
14821486
func (dp *DatabaseProvider) PutChatV2(ctx context.Context, record *chat_v2.ChatRecord) error {
14831487
return dp.chatv2.PutChat(ctx, record)

pkg/code/server/grpc/chat/v2/server.go

Lines changed: 44 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -299,8 +299,8 @@ func (s *server) StreamChatEvents(streamer chatpb.Chat_StreamChatEventsServer) e
299299
sendPingCh := time.After(0)
300300
streamHealthCh := monitorChatEventStreamHealth(ctx, log, streamRef, streamer)
301301

302-
// todo: We should also "flush" pointers for each chat member
303302
go s.flushMessages(ctx, chatId, owner, stream)
303+
go s.flushPointers(ctx, chatId, stream)
304304

305305
for {
306306
select {
@@ -385,6 +385,49 @@ func (s *server) flushMessages(ctx context.Context, chatId chat.ChatId, owner *c
385385
}
386386
}
387387

388+
func (s *server) flushPointers(ctx context.Context, chatId chat.ChatId, stream *chatEventStream) {
389+
log := s.log.WithFields(logrus.Fields{
390+
"method": "flushPointers",
391+
"chat_id": chatId.String(),
392+
})
393+
394+
memberRecords, err := s.data.GetAllChatMembersV2(ctx, chatId)
395+
if err == chat.ErrMemberNotFound {
396+
return
397+
} else if err != nil {
398+
log.WithError(err).Warn("failure getting chat members")
399+
return
400+
}
401+
402+
for _, memberRecord := range memberRecords {
403+
for _, optionalPointer := range []struct {
404+
kind chat.PointerType
405+
value *chat.MessageId
406+
}{
407+
{chat.PointerTypeDelivered, memberRecord.DeliveryPointer},
408+
{chat.PointerTypeRead, memberRecord.ReadPointer},
409+
} {
410+
if optionalPointer.value == nil {
411+
continue
412+
}
413+
414+
event := &chatpb.ChatStreamEvent{
415+
Type: &chatpb.ChatStreamEvent_Pointer{
416+
Pointer: &chatpb.Pointer{
417+
Kind: optionalPointer.kind.ToProto(),
418+
Value: optionalPointer.value.ToProto(),
419+
MemberId: memberRecord.MemberId.ToProto(),
420+
},
421+
},
422+
}
423+
if err := stream.notify(event, streamNotifyTimeout); err != nil {
424+
log.WithError(err).Warnf("failed to notify session stream, closing streamer (stream=%p)", stream)
425+
return
426+
}
427+
}
428+
}
429+
}
430+
388431
func (s *server) SendMessage(ctx context.Context, req *chatpb.SendMessageRequest) (*chatpb.SendMessageResponse, error) {
389432
log := s.log.WithField("method", "SendMessage")
390433
log = client.InjectLoggingMetadata(ctx, log)

0 commit comments

Comments
 (0)