Skip to content

Commit 602a42a

Browse files
committed
adjust websocket
1 parent 7cb8e2f commit 602a42a

File tree

6 files changed

+187
-66
lines changed

6 files changed

+187
-66
lines changed

services/pubsub/memory.go

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
// Copyright 2024 The Gitea Authors. All rights reserved.
2+
// SPDX-License-Identifier: MIT
3+
4+
package pubsub
5+
6+
import (
7+
"context"
8+
"sync"
9+
)
10+
11+
type Memory struct {
12+
sync.Mutex
13+
14+
topics map[string]map[*Subscriber]struct{}
15+
}
16+
17+
// New creates an in-memory publisher.
18+
func NewMemory() Broker {
19+
return &Memory{
20+
topics: make(map[string]map[*Subscriber]struct{}),
21+
}
22+
}
23+
24+
func (p *Memory) Publish(_ context.Context, message Message) {
25+
p.Lock()
26+
27+
topic, ok := p.topics[message.Topic]
28+
if !ok {
29+
p.Unlock()
30+
return
31+
}
32+
33+
for s := range topic {
34+
go (*s)(message)
35+
}
36+
p.Unlock()
37+
}
38+
39+
func (p *Memory) Subscribe(c context.Context, topic string, subscriber Subscriber) {
40+
// Subscribe
41+
p.Lock()
42+
_, ok := p.topics[topic]
43+
if !ok {
44+
p.topics[topic] = make(map[*Subscriber]struct{})
45+
}
46+
p.topics[topic][&subscriber] = struct{}{}
47+
p.Unlock()
48+
49+
// Wait for context to be done
50+
<-c.Done()
51+
52+
// Unsubscribe
53+
p.Lock()
54+
delete(p.topics[topic], &subscriber)
55+
if len(p.topics[topic]) == 0 {
56+
delete(p.topics, topic)
57+
}
58+
p.Unlock()
59+
}

services/pubsub/memory_test.go

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
// Copyright 2024 The Gitea Authors. All rights reserved.
2+
// SPDX-License-Identifier: MIT
3+
4+
package pubsub
5+
6+
import (
7+
"context"
8+
"sync"
9+
"testing"
10+
"time"
11+
12+
"github.com/stretchr/testify/assert"
13+
)
14+
15+
func TestPubsub(t *testing.T) {
16+
var (
17+
wg sync.WaitGroup
18+
19+
testMessage = Message{
20+
Data: []byte("test"),
21+
Topic: "hello-world",
22+
}
23+
)
24+
25+
ctx, cancel := context.WithCancelCause(
26+
context.Background(),
27+
)
28+
29+
broker := NewMemory()
30+
go func() {
31+
broker.Subscribe(ctx, "hello-world", func(message Message) { assert.Equal(t, testMessage, message); wg.Done() })
32+
}()
33+
go func() {
34+
broker.Subscribe(ctx, "hello-world", func(_ Message) { wg.Done() })
35+
}()
36+
37+
// Wait a bit for the subscriptions to be registered
38+
<-time.After(100 * time.Millisecond)
39+
40+
wg.Add(2)
41+
go func() {
42+
broker.Publish(ctx, testMessage)
43+
}()
44+
45+
wg.Wait()
46+
cancel(nil)
47+
}

services/pubsub/types.go

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
// Copyright 2024 The Gitea Authors. All rights reserved.
2+
// SPDX-License-Identifier: MIT
3+
4+
package pubsub
5+
6+
import "context"
7+
8+
// Message defines a published message.
9+
type Message struct {
10+
// Data is the actual data in the entry.
11+
Data []byte `json:"data"`
12+
13+
// Topic is the topic of the message.
14+
Topic string `json:"topic"`
15+
}
16+
17+
// Subscriber receives published messages.
18+
type Subscriber func(Message)
19+
20+
type Broker interface {
21+
Publish(c context.Context, message Message)
22+
Subscribe(c context.Context, topic string, subscriber Subscriber)
23+
}

services/websocket/issue_comment_notifier.go

Lines changed: 10 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -5,45 +5,22 @@ package websocket
55

66
import (
77
"context"
8+
"encoding/json"
89
"fmt"
910

1011
issues_model "code.gitea.io/gitea/models/issues"
11-
"code.gitea.io/gitea/models/perm"
12-
"code.gitea.io/gitea/models/perm/access"
13-
repo_model "code.gitea.io/gitea/models/repo"
14-
"code.gitea.io/gitea/models/unit"
1512
user_model "code.gitea.io/gitea/models/user"
16-
"code.gitea.io/gitea/modules/log"
17-
18-
"github.com/olahol/melody"
13+
"code.gitea.io/gitea/services/pubsub"
1914
)
2015

21-
func (n *websocketNotifier) filterIssueSessions(ctx context.Context, repo *repo_model.Repository, issue *issues_model.Issue) []*melody.Session {
22-
return n.filterSessions(func(s *melody.Session, data *sessionData) bool {
23-
// if the user is watching the issue, they will get notifications
24-
if !data.isOnURL(fmt.Sprintf("/%s/%s/issues/%d", repo.Owner.Name, repo.Name, issue.Index)) {
25-
return false
26-
}
27-
28-
// the user will get notifications if they have access to the repos issues
29-
hasAccess, err := access.HasAccessUnit(ctx, data.user, repo, unit.TypeIssues, perm.AccessModeRead)
30-
if err != nil {
31-
log.Error("Failed to check access: %v", err)
32-
return false
33-
}
34-
35-
return hasAccess
36-
})
37-
}
38-
3916
func (n *websocketNotifier) DeleteComment(ctx context.Context, doer *user_model.User, c *issues_model.Comment) {
40-
sessions := n.filterIssueSessions(ctx, c.Issue.Repo, c.Issue)
41-
42-
for _, s := range sessions {
43-
msg := fmt.Sprintf(htmxRemoveElement, fmt.Sprintf("#%s", c.HashTag()))
44-
err := s.Write([]byte(msg))
45-
if err != nil {
46-
log.Error("Failed to write to session: %v", err)
47-
}
17+
d, err := json.Marshal(c)
18+
if err != nil {
19+
return
4820
}
21+
22+
n.pubsub.Publish(ctx, pubsub.Message{
23+
Data: d,
24+
Topic: fmt.Sprintf("repo:%s/%s", c.RefRepo.OwnerName, c.RefRepo.Name),
25+
})
4926
}

services/websocket/notifier.go

Lines changed: 4 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -4,17 +4,18 @@
44
package websocket
55

66
import (
7-
"code.gitea.io/gitea/modules/log"
87
"code.gitea.io/gitea/modules/templates"
98
notify_service "code.gitea.io/gitea/services/notify"
9+
"code.gitea.io/gitea/services/pubsub"
1010

1111
"github.com/olahol/melody"
1212
)
1313

1414
type websocketNotifier struct {
1515
notify_service.NullNotifier
16-
m *melody.Melody
17-
rnd *templates.HTMLRender
16+
m *melody.Melody
17+
rnd *templates.HTMLRender
18+
pubsub pubsub.Broker
1819
}
1920

2021
// NewNotifier create a new webhooksNotifier notifier
@@ -29,25 +30,3 @@ func newNotifier(m *melody.Melody) notify_service.Notifier {
2930
// htmxUpdateElement = "<div hx-swap-oob=\"outerHTML:%s\">%s</div>"
3031

3132
var htmxRemoveElement = "<div hx-swap-oob=\"delete:%s\"></div>"
32-
33-
func (n *websocketNotifier) filterSessions(fn func(*melody.Session, *sessionData) bool) []*melody.Session {
34-
sessions, err := n.m.Sessions()
35-
if err != nil {
36-
log.Error("Failed to get sessions: %v", err)
37-
return nil
38-
}
39-
40-
_sessions := make([]*melody.Session, 0, len(sessions))
41-
for _, s := range sessions {
42-
data, err := getSessionData(s)
43-
if err != nil {
44-
continue
45-
}
46-
47-
if fn(s, data) {
48-
_sessions = append(_sessions, s)
49-
}
50-
}
51-
52-
return _sessions
53-
}

services/websocket/websocket.go

Lines changed: 44 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,17 @@
44
package websocket
55

66
import (
7+
goContext "context"
8+
"fmt"
9+
10+
"code.gitea.io/gitea/models/perm"
11+
"code.gitea.io/gitea/models/perm/access"
12+
"code.gitea.io/gitea/models/unit"
713
"code.gitea.io/gitea/modules/json"
14+
"code.gitea.io/gitea/modules/log"
815
"code.gitea.io/gitea/services/context"
916
notify_service "code.gitea.io/gitea/services/notify"
17+
"code.gitea.io/gitea/services/pubsub"
1018

1119
"github.com/mitchellh/mapstructure"
1220
"github.com/olahol/melody"
@@ -20,19 +28,26 @@ type websocketMessage struct {
2028
}
2129

2230
type subscribeMessageData struct {
23-
URL string `json:"url"`
31+
Repo string `json:"repo"`
2432
}
2533

2634
func Init() *melody.Melody {
2735
m = melody.New()
28-
m.HandleConnect(handleConnect)
29-
m.HandleMessage(handleMessage)
36+
hub := &hub{
37+
pubsub: pubsub.NewMemory(),
38+
}
39+
m.HandleConnect(hub.handleConnect)
40+
m.HandleMessage(hub.handleMessage)
3041
m.HandleDisconnect(handleDisconnect)
3142
notify_service.RegisterNotifier(newNotifier(m))
3243
return m
3344
}
3445

35-
func handleConnect(s *melody.Session) {
46+
type hub struct {
47+
pubsub pubsub.Broker
48+
}
49+
50+
func (h *hub) handleConnect(s *melody.Session) {
3651
ctx := context.GetWebContext(s.Request)
3752

3853
data := &sessionData{}
@@ -45,7 +60,7 @@ func handleConnect(s *melody.Session) {
4560
// TODO: handle logouts
4661
}
4762

48-
func handleMessage(s *melody.Session, _msg []byte) {
63+
func (h *hub) handleMessage(s *melody.Session, _msg []byte) {
4964
data, err := getSessionData(s)
5065
if err != nil {
5166
return
@@ -59,21 +74,42 @@ func handleMessage(s *melody.Session, _msg []byte) {
5974

6075
switch msg.Action {
6176
case "subscribe":
62-
err := handleSubscribeMessage(data, msg.Data)
77+
err := h.handleSubscribeMessage(s, data, msg.Data)
6378
if err != nil {
6479
return
6580
}
6681
}
6782
}
6883

69-
func handleSubscribeMessage(data *sessionData, _data any) error {
84+
func (h *hub) handleSubscribeMessage(s *melody.Session, data *sessionData, _data any) error {
7085
msgData := &subscribeMessageData{}
7186
err := mapstructure.Decode(_data, &msgData)
7287
if err != nil {
7388
return err
7489
}
7590

76-
data.onURL = msgData.URL
91+
ctx := goContext.Background() // TODO: use proper context
92+
h.pubsub.Subscribe(ctx, msgData.Repo, func(msg pubsub.Message) {
93+
if data.user != nil {
94+
return
95+
}
96+
97+
// TODO: check permissions
98+
hasAccess, err := access.HasAccessUnit(ctx, data.user, repo, unit.TypeIssues, perm.AccessModeRead)
99+
if err != nil {
100+
log.Error("Failed to check access: %v", err)
101+
return
102+
}
103+
104+
if !hasAccess {
105+
return
106+
}
107+
108+
// TODO: check the actual data received from pubsub and send it correctly to the client
109+
d := fmt.Sprintf(htmxRemoveElement, fmt.Sprintf("#%s", c.HashTag()))
110+
_ = s.Write([]byte(d))
111+
})
112+
77113
return nil
78114
}
79115

0 commit comments

Comments
 (0)