Skip to content

Commit 6a8c955

Browse files
authored
Merge pull request #79 from matrix-org/travis/server-communities/02-text-api
Support server-centric communities being able to check text content against filters
2 parents 24f5fdf + 4ca8e66 commit 6a8c955

17 files changed

+468
-33
lines changed

api/api.go

Lines changed: 50 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,12 @@
11
package api
22

33
import (
4+
"context"
45
"log"
56
"net/http"
7+
"time"
68

9+
"github.com/matrix-org/policyserv/community"
710
"github.com/matrix-org/policyserv/homeserver"
811
"github.com/matrix-org/policyserv/metrics"
912
"github.com/matrix-org/policyserv/storage"
@@ -16,18 +19,20 @@ type Config struct {
1619
}
1720

1821
type Api struct {
19-
storage storage.PersistentStorage
20-
hs *homeserver.Homeserver
21-
apiKey string
22-
joinViaServer string
22+
storage storage.PersistentStorage
23+
hs *homeserver.Homeserver
24+
communityManager *community.Manager
25+
apiKey string
26+
joinViaServer string
2327
}
2428

25-
func NewApi(config *Config, storage storage.PersistentStorage, hs *homeserver.Homeserver) (*Api, error) {
29+
func NewApi(config *Config, storage storage.PersistentStorage, hs *homeserver.Homeserver, communityManager *community.Manager) (*Api, error) {
2630
return &Api{
27-
storage: storage,
28-
hs: hs,
29-
apiKey: config.ApiKey,
30-
joinViaServer: config.JoinViaServer,
31+
storage: storage,
32+
hs: hs,
33+
communityManager: communityManager,
34+
apiKey: config.ApiKey,
35+
joinViaServer: config.JoinViaServer,
3136
}, nil
3237
}
3338

@@ -49,11 +54,46 @@ func (a *Api) httpAuthenticatedRequestHandler(upstream func(api *Api, w http.Res
4954
})
5055
}
5156

57+
func (a *Api) httpCommunityAuthenticatedRequestHandler(upstream func(api *Api, community *storage.StoredCommunity, w http.ResponseWriter, r *http.Request)) http.Handler {
58+
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
59+
authHeader := r.Header.Get("Authorization")
60+
if len(authHeader) <= len("Bearer ") {
61+
defer metrics.RecordHttpResponse(r.Method, "httpCommunityAuthenticatedRequestHandler", http.StatusUnauthorized)
62+
homeserver.MatrixHttpError(w, http.StatusUnauthorized, "M_UNAUTHORIZED", "Not allowed")
63+
return
64+
}
65+
66+
// Set a quick timeout that only affects the community lookup/authentication
67+
fastContext, cancel := context.WithTimeout(r.Context(), 2*time.Second)
68+
defer cancel()
69+
70+
accessToken := authHeader[len("Bearer "):]
71+
community, err := a.storage.GetCommunityByAccessToken(fastContext, accessToken)
72+
if err != nil {
73+
log.Println(err)
74+
defer metrics.RecordHttpResponse(r.Method, "httpCommunityAuthenticatedRequestHandler", http.StatusInternalServerError)
75+
homeserver.MatrixHttpError(w, http.StatusInternalServerError, "M_UNKNOWN", "Server error")
76+
return
77+
}
78+
if community == nil {
79+
defer metrics.RecordHttpResponse(r.Method, "httpCommunityAuthenticatedRequestHandler", http.StatusUnauthorized)
80+
homeserver.MatrixHttpError(w, http.StatusUnauthorized, "M_UNAUTHORIZED", "Not allowed")
81+
return
82+
}
83+
84+
upstream(a, community, w, r)
85+
})
86+
}
87+
5288
func (a *Api) BindTo(mux *http.ServeMux) error {
5389
mux.Handle("/", a.httpRequestHandler(httpCatchAll))
5490
mux.Handle("/health", a.httpRequestHandler(httpHealth))
5591
mux.Handle("/ready", a.httpRequestHandler(httpReady))
5692

93+
// Server-centric community API
94+
mux.Handle("/_policyserv/v1/check/text", a.httpCommunityAuthenticatedRequestHandler(httpCheckTextCommunityApi))
95+
96+
// Admin API
5797
if a.apiKey != "" {
5898
log.Println("Enabling policyserv API")
5999
mux.Handle("/api/v1/rooms", a.httpAuthenticatedRequestHandler(httpGetRoomsApi))
@@ -63,6 +103,7 @@ func (a *Api) BindTo(mux *http.ServeMux) error {
63103
mux.Handle("/api/v1/communities/new", a.httpAuthenticatedRequestHandler(httpCreateCommunityApi))
64104
mux.Handle("/api/v1/communities/{id}", a.httpAuthenticatedRequestHandler(httpGetCommunityApi))
65105
mux.Handle("/api/v1/communities/{id}/config", a.httpAuthenticatedRequestHandler(httpSetCommunityConfigApi))
106+
mux.Handle("/api/v1/communities/{id}/rotate_access_token", a.httpAuthenticatedRequestHandler(httpRotateCommunityAccessTokenApi))
66107
mux.Handle("/api/v1/instance/community_config", a.httpAuthenticatedRequestHandler(httpGetInstanceConfigApi))
67108
mux.Handle("/api/v1/sources/muninn/set_member_directory_event", a.httpAuthenticatedRequestHandler(httpSetMuninnSourceData))
68109
mux.Handle("/api/v1/keyword_templates/{name}", a.httpAuthenticatedRequestHandler(httpKeywordTemplates))

api/api_test.go

Lines changed: 69 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package api
22

33
import (
4+
"context"
45
"crypto/ed25519"
56
"net/http"
67
"net/http/httptest"
@@ -9,7 +10,9 @@ import (
910
"github.com/matrix-org/policyserv/community"
1011
"github.com/matrix-org/policyserv/config"
1112
"github.com/matrix-org/policyserv/homeserver"
13+
"github.com/matrix-org/policyserv/internal"
1214
"github.com/matrix-org/policyserv/queue"
15+
"github.com/matrix-org/policyserv/storage"
1316
"github.com/matrix-org/policyserv/test"
1417
"github.com/stretchr/testify/assert"
1518
)
@@ -53,7 +56,7 @@ func makeApi(t *testing.T) *Api {
5356

5457
api, err := NewApi(&Config{
5558
ApiKey: testApiKey,
56-
}, db, hs)
59+
}, db, hs, communityManager)
5760
assert.NoError(t, err)
5861
assert.NotNil(t, api)
5962

@@ -112,3 +115,68 @@ func TestAuthenticatedApi(t *testing.T) {
112115
assert.Equal(t, w.Code, http.StatusOK)
113116
assert.True(t, called)
114117
}
118+
119+
func TestCommunityAuthenticatedApiNoAuth(t *testing.T) {
120+
t.Parallel()
121+
122+
api := makeApi(t)
123+
124+
w := httptest.NewRecorder()
125+
r := httptest.NewRequest(http.MethodGet, "/example", nil)
126+
//r.Header.Set("Authorization", "Bearer WRONG_TOKEN") // we don't want auth on this test, so don't set it
127+
upstream := func(a *Api, c *storage.StoredCommunity, w http.ResponseWriter, r *http.Request) {
128+
assert.Fail(t, "should not be called")
129+
}
130+
handler := api.httpCommunityAuthenticatedRequestHandler(upstream)
131+
handler.ServeHTTP(w, r)
132+
assert.Equal(t, w.Code, http.StatusUnauthorized)
133+
test.AssertApiError(t, w, "M_UNAUTHORIZED", "Not allowed")
134+
}
135+
136+
func TestCommunityAuthenticatedApiWrongAuth(t *testing.T) {
137+
t.Parallel()
138+
139+
api := makeApi(t)
140+
141+
w := httptest.NewRecorder()
142+
r := httptest.NewRequest(http.MethodGet, "/example", nil)
143+
r.Header.Set("Authorization", "Bearer WRONG_TOKEN")
144+
upstream := func(a *Api, c *storage.StoredCommunity, w http.ResponseWriter, r *http.Request) {
145+
assert.Fail(t, "should not be called")
146+
}
147+
handler := api.httpCommunityAuthenticatedRequestHandler(upstream)
148+
handler.ServeHTTP(w, r)
149+
assert.Equal(t, w.Code, http.StatusUnauthorized)
150+
test.AssertApiError(t, w, "M_UNAUTHORIZED", "Not allowed")
151+
}
152+
153+
func createCommunityWithAccessToken(t *testing.T, api *Api) *storage.StoredCommunity {
154+
serverCommunity, err := api.storage.CreateCommunity(context.Background(), "Test Community")
155+
assert.NoError(t, err)
156+
assert.NotNil(t, serverCommunity)
157+
serverCommunity.ApiAccessToken = internal.Pointer("pst_TESTING_COMMUNITY")
158+
err = api.storage.UpsertCommunity(context.Background(), serverCommunity)
159+
assert.NoError(t, err)
160+
return serverCommunity
161+
}
162+
163+
func TestCommunityAuthenticatedApi(t *testing.T) {
164+
t.Parallel()
165+
166+
api := makeApi(t)
167+
serverCommunity := createCommunityWithAccessToken(t, api)
168+
169+
w := httptest.NewRecorder()
170+
r := httptest.NewRequest(http.MethodGet, "/example", nil)
171+
r.Header.Set("Authorization", "Bearer "+internal.Dereference(serverCommunity.ApiAccessToken))
172+
called := false
173+
upstream := func(a *Api, c *storage.StoredCommunity, w http.ResponseWriter, r *http.Request) {
174+
assert.Equal(t, serverCommunity, c)
175+
called = true
176+
w.WriteHeader(http.StatusOK)
177+
}
178+
handler := api.httpCommunityAuthenticatedRequestHandler(upstream)
179+
handler.ServeHTTP(w, r)
180+
assert.Equal(t, w.Code, http.StatusOK)
181+
assert.True(t, called)
182+
}

api/http_api_communities.go

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,13 @@
11
package api
22

33
import (
4+
"crypto/rand"
5+
"fmt"
46
"net/http"
57
"strings"
68

79
"github.com/matrix-org/policyserv/config"
10+
"github.com/matrix-org/policyserv/internal"
811
"github.com/matrix-org/policyserv/metrics"
912
)
1013

@@ -124,3 +127,46 @@ func httpSetCommunityConfigApi(api *Api, w http.ResponseWriter, r *http.Request)
124127
return
125128
}
126129
}
130+
131+
func httpRotateCommunityAccessTokenApi(api *Api, w http.ResponseWriter, r *http.Request) {
132+
metrics.RecordHttpRequest(r.Method, "httpRotateCommunityAccessTokenApi")
133+
t := metrics.StartRequestTimer(r.Method, "httpRotateCommunityAccessTokenApi")
134+
defer t.ObserveDuration()
135+
136+
errs := newErrorResponder("httpRotateCommunityAccessTokenApi", w, r)
137+
138+
if r.Method != http.MethodPost {
139+
errs.text(http.StatusMethodNotAllowed, "M_UNRECOGNIZED", "Method not allowed")
140+
return
141+
}
142+
143+
id := r.PathValue("id")
144+
community, err := api.storage.GetCommunity(r.Context(), id)
145+
if err != nil {
146+
errs.err(http.StatusInternalServerError, "M_UNKNOWN", err)
147+
return
148+
}
149+
if community == nil {
150+
errs.text(http.StatusNotFound, "M_NOT_FOUND", "Community not found")
151+
return
152+
}
153+
154+
oldAccessToken := internal.Dereference(community.ApiAccessToken)
155+
156+
newAccessToken := fmt.Sprintf("pst_%s", rand.Text())
157+
community.ApiAccessToken = internal.Pointer(newAccessToken)
158+
err = api.storage.UpsertCommunity(r.Context(), community)
159+
if err != nil {
160+
errs.err(http.StatusInternalServerError, "M_UNKNOWN", err)
161+
return
162+
}
163+
164+
err = respondJson("httpRotateCommunityAccessTokenApi", r, w, map[string]string{
165+
"old_access_token": oldAccessToken,
166+
"new_access_token": newAccessToken,
167+
})
168+
if err != nil {
169+
errs.err(http.StatusInternalServerError, "M_UNKNOWN", err)
170+
return
171+
}
172+
}

api/http_api_communities_test.go

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ import (
99
"testing"
1010

1111
"github.com/matrix-org/policyserv/config"
12+
"github.com/matrix-org/policyserv/internal"
1213
"github.com/matrix-org/policyserv/storage"
1314
"github.com/matrix-org/policyserv/test"
1415
"github.com/stretchr/testify/assert"
@@ -75,13 +76,15 @@ func TestCreateCommunityCreate(t *testing.T) {
7576
assert.Equal(t, communityName, community.Name)
7677
assert.NotEmpty(t, community.CommunityId)
7778
assert.NotNil(t, community.Config)
79+
assert.Nil(t, community.ApiAccessToken) // no access token should be set on creation
7880

7981
// Ensure it was also stored
8082
fromDb, err := api.storage.GetCommunity(context.Background(), community.CommunityId)
8183
assert.NoError(t, err)
8284
assert.NotNil(t, fromDb)
8385
assert.Equal(t, communityName, fromDb.Name)
8486
assert.Equal(t, community.CommunityId, fromDb.CommunityId)
87+
assert.Nil(t, fromDb.ApiAccessToken) // no access token should be set on creation
8588

8689
// Note: we can't (currently) test that errors during database calls and HTTP responses are handled. A future test
8790
// case *should* cover this.
@@ -124,6 +127,12 @@ func TestGetCommunity(t *testing.T) {
124127
assert.NotEmpty(t, community.CommunityId)
125128
assert.Equal(t, name, community.Name)
126129

130+
// Set an access token for the community. This is to ensure we don't leak it through the request.
131+
community.ApiAccessToken = internal.Pointer("pst_TESTING")
132+
err = api.storage.UpsertCommunity(context.Background(), community)
133+
assert.NoError(t, err)
134+
community.ApiAccessToken = nil // so the assert.Equal() passes later
135+
127136
w := httptest.NewRecorder()
128137
r := httptest.NewRequest(http.MethodGet, "/api/v1/communities/"+community.CommunityId, nil)
129138
r.SetPathValue("id", community.CommunityId)
@@ -178,6 +187,12 @@ func TestSetCommunityConfig(t *testing.T) {
178187
assert.NotEmpty(t, community.CommunityId)
179188
assert.Equal(t, name, community.Name)
180189

190+
// Set an access token for the community. This is to ensure we don't leak it through the request.
191+
community.ApiAccessToken = internal.Pointer("pst_TESTING")
192+
err = api.storage.UpsertCommunity(context.Background(), community)
193+
assert.NoError(t, err)
194+
community.ApiAccessToken = nil // so the assert.Equal() passes later
195+
181196
cnf := &config.CommunityConfig{
182197
KeywordFilterKeywords: &[]string{"keyword1", "keyword2"},
183198
}
@@ -196,3 +211,74 @@ func TestSetCommunityConfig(t *testing.T) {
196211
// Note: we can't (currently) test that errors during database calls and HTTP responses are handled. A future test
197212
// case *should* cover this.
198213
}
214+
215+
func TestRotateCommunityAccessTokenWrongMethod(t *testing.T) {
216+
t.Parallel()
217+
218+
api := makeApi(t)
219+
220+
w := httptest.NewRecorder()
221+
r := httptest.NewRequest(http.MethodGet /*this should be POST*/, "/api/v1/communities/not_a_real_id/rotate_access_token", nil)
222+
httpSetCommunityConfigApi(api, w, r)
223+
assert.Equal(t, http.StatusMethodNotAllowed, w.Code)
224+
test.AssertApiError(t, w, "M_UNRECOGNIZED", "Method not allowed")
225+
}
226+
227+
func TestRotateCommunityAccessTokenNotFound(t *testing.T) {
228+
t.Parallel()
229+
230+
api := makeApi(t)
231+
232+
cnf := &config.CommunityConfig{
233+
KeywordFilterKeywords: &[]string{"keyword1", "keyword2"},
234+
}
235+
w := httptest.NewRecorder()
236+
r := httptest.NewRequest(http.MethodPost, "/api/v1/communities/not_a_real_id/rotate_access_token", test.MakeJsonBody(t, cnf))
237+
r.SetPathValue("id", "not_a_real_id")
238+
httpSetCommunityConfigApi(api, w, r)
239+
assert.Equal(t, http.StatusNotFound, w.Code)
240+
test.AssertApiError(t, w, "M_NOT_FOUND", "Community not found")
241+
}
242+
243+
func TestRotateCommunityAccessToken(t *testing.T) {
244+
t.Parallel()
245+
246+
api := makeApi(t)
247+
248+
name := "Test Community"
249+
community, err := api.storage.CreateCommunity(context.Background(), name)
250+
assert.NoError(t, err)
251+
assert.NotNil(t, community)
252+
assert.NotEmpty(t, community.CommunityId)
253+
assert.Equal(t, name, community.Name)
254+
assert.Nil(t, community.ApiAccessToken) // should be created without an access token
255+
256+
// First rotation should have an empty "old_access_token" and a non-empty "new_access_token"
257+
w := httptest.NewRecorder()
258+
r := httptest.NewRequest(http.MethodPost, "/api/v1/communities/"+community.CommunityId+"/rotate_access_token", nil)
259+
r.SetPathValue("id", community.CommunityId)
260+
httpRotateCommunityAccessTokenApi(api, w, r)
261+
assert.Equal(t, http.StatusOK, w.Code)
262+
fromRes := make(map[string]string)
263+
err = json.Unmarshal(w.Body.Bytes(), &fromRes)
264+
assert.NoError(t, err)
265+
assert.Empty(t, fromRes["old_access_token"])
266+
assert.NotEmpty(t, fromRes["new_access_token"])
267+
268+
// Verify the access token was persisted by the HTTP handler
269+
accessToken := fromRes["new_access_token"]
270+
fromDb, err := api.storage.GetCommunity(context.Background(), community.CommunityId)
271+
assert.NoError(t, err)
272+
assert.NotNil(t, fromDb)
273+
assert.Equal(t, accessToken, internal.Dereference(fromDb.ApiAccessToken))
274+
275+
// Second rotation should reference what was the last request's "new" access token and generate yet another new one
276+
w = httptest.NewRecorder()
277+
httpRotateCommunityAccessTokenApi(api, w, r)
278+
assert.Equal(t, http.StatusOK, w.Code)
279+
fromRes = make(map[string]string)
280+
err = json.Unmarshal(w.Body.Bytes(), &fromRes)
281+
assert.NoError(t, err)
282+
assert.Equal(t, accessToken, fromRes["old_access_token"]) // old token should match previous rotation
283+
assert.NotEmpty(t, fromRes["new_access_token"])
284+
}

0 commit comments

Comments
 (0)