@@ -3,6 +3,7 @@ package tests
33import (
44 "context"
55 "encoding/json"
6+ "fmt"
67 "net/http"
78 "net/url"
89 "testing"
@@ -264,3 +265,126 @@ func TestBannedUserCannotSendJoin(t *testing.T) {
264265 membership := must .GetJSONFieldStr (t , stateResp , "membership" )
265266 must .EqualStr (t , membership , "ban" , "membership of charlie" )
266267}
268+
269+ // This test checks that we cannot submit anything via /v1/send_join except a join.
270+ func TestCannotSendNonJoinViaSendJoinV1 (t * testing.T ) {
271+ testValidationForSendMembershipEndpoint (t , "/_matrix/federation/v1/send_join" , "join" , nil )
272+ }
273+
274+ // This test checks that we cannot submit anything via /v2/send_join except a join.
275+ func TestCannotSendNonJoinViaSendJoinV2 (t * testing.T ) {
276+ testValidationForSendMembershipEndpoint (t , "/_matrix/federation/v2/send_join" , "join" , nil )
277+ }
278+
279+ // This test checks that we cannot submit anything via /v1/send_leave except a leave.
280+ func TestCannotSendNonLeaveViaSendLeaveV1 (t * testing.T ) {
281+ testValidationForSendMembershipEndpoint (t , "/_matrix/federation/v1/send_leave" , "leave" , nil )
282+ }
283+
284+ // This test checks that we cannot submit anything via /v2/send_leave except a leave.
285+ func TestCannotSendNonLeaveViaSendLeaveV2 (t * testing.T ) {
286+ testValidationForSendMembershipEndpoint (t , "/_matrix/federation/v2/send_leave" , "leave" , nil )
287+ }
288+
289+ // testValidationForSendMembershipEndpoint attempts to submit a range of events via the given endpoint
290+ // and checks that they are all rejected.
291+ func testValidationForSendMembershipEndpoint (t * testing.T , baseApiPath , expectedMembership string , createRoomOpts map [string ]interface {}) {
292+ if createRoomOpts == nil {
293+ createRoomOpts = make (map [string ]interface {})
294+ }
295+
296+ deployment := Deploy (t , b .BlueprintAlice )
297+ defer deployment .Destroy (t )
298+
299+ srv := federation .NewServer (t , deployment ,
300+ federation .HandleKeyRequests (),
301+ federation .HandleTransactionRequests (nil , nil ),
302+ )
303+ cancel := srv .Listen ()
304+ defer cancel ()
305+
306+ // alice creates a room, and charlie joins it
307+ alice := deployment .Client (t , "hs1" , "@alice:hs1" )
308+ roomId := alice .CreateRoom (t , createRoomOpts )
309+ charlie := srv .UserID ("charlie" )
310+ room := srv .MustJoinRoom (t , deployment , "hs1" , roomId , charlie )
311+
312+ // a helper function which makes a send_* request to the given path and checks
313+ // that it fails with a 400 error
314+ assertRequestFails := func (t * testing.T , event * gomatrixserverlib.Event ) {
315+ path := fmt .Sprintf ("%s/%s/%s" ,
316+ baseApiPath ,
317+ url .PathEscape (event .RoomID ()),
318+ url .PathEscape (event .EventID ()),
319+ )
320+ t .Logf ("PUT %s" , path )
321+ req := gomatrixserverlib .NewFederationRequest ("PUT" , "hs1" , path )
322+ if err := req .SetContent (event ); err != nil {
323+ t .Errorf ("req.SetContent: %v" , err )
324+ return
325+ }
326+
327+ var res map [string ]interface {}
328+ err := srv .SendFederationRequest (deployment , req , & res )
329+ if err == nil {
330+ t .Errorf ("send request returned 200" )
331+ return
332+ }
333+
334+ httpError , ok := err .(gomatrix.HTTPError )
335+ if ! ok {
336+ t .Errorf ("not an HTTPError: %v" , err )
337+ return
338+ }
339+
340+ t .Logf ("%s returned %d/%s" , baseApiPath , httpError .Code , string (httpError .Contents ))
341+ if httpError .Code != 400 {
342+ t .Errorf ("expected 400, got %d" , httpError .Code )
343+ }
344+ }
345+
346+ t .Run ("regular event" , func (t * testing.T ) {
347+ event := srv .MustCreateEvent (t , room , b.Event {
348+ Type : "m.room.message" ,
349+ Sender : charlie ,
350+ Content : map [string ]interface {}{"body" : "bzz" },
351+ })
352+ assertRequestFails (t , event )
353+ })
354+ t .Run ("non-state membership event" , func (t * testing.T ) {
355+ event := srv .MustCreateEvent (t , room , b.Event {
356+ Type : "m.room.member" ,
357+ Sender : charlie ,
358+ Content : map [string ]interface {}{"body" : "bzz" },
359+ })
360+ assertRequestFails (t , event )
361+ })
362+
363+ // try membership events of various types, other than that expected by
364+ // the endpoint
365+ for _ , membershipType := range []string {"join" , "leave" , "knock" , "invite" } {
366+ if membershipType == expectedMembership {
367+ continue
368+ }
369+ event := srv .MustCreateEvent (t , room , b.Event {
370+ Type : "m.room.member" ,
371+ Sender : charlie ,
372+ StateKey : & charlie ,
373+ Content : map [string ]interface {}{"membership" : membershipType },
374+ })
375+ t .Run (membershipType + " event" , func (t * testing.T ) {
376+ assertRequestFails (t , event )
377+ })
378+ }
379+
380+ // right sort of membership, but mismatched state_key
381+ t .Run ("event with mismatched state key" , func (t * testing.T ) {
382+ event := srv .MustCreateEvent (t , room , b.Event {
383+ Type : "m.room.member" ,
384+ Sender : charlie ,
385+ StateKey : b .Ptr (srv .UserID ("doris" )),
386+ Content : map [string ]interface {}{"membership" : expectedMembership },
387+ })
388+ assertRequestFails (t , event )
389+ })
390+ }
0 commit comments