Skip to content
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
.*.swp
coverage.out
72 changes: 66 additions & 6 deletions event.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (
"errors"
"fmt"
"reflect"
"regexp"
"strings"
"time"

Expand Down Expand Up @@ -1111,14 +1112,19 @@ func (er EventReference) MarshalJSON() ([]byte, error) {
return json.Marshal(&tuple)
}

// SplitID splits a matrix ID into a local part and a server name.
func SplitID(sigil byte, id string) (local string, domain ServerName, err error) {
// IDs have the format: SIGIL LOCALPART ":" DOMAIN
const UserIDSigil = '@'
const RoomAliasSigil = '#'
const RoomIDSigil = '!'
const EventIDSigil = '$'
const GroupIDSigil = '+'
const UserPermanentKeySigil = '~'
const UserDelegatedKeySigil = '^'

var IsValidMXID = regexp.MustCompile(`^[a-z0-9\_\-\.\/\=]+$`).MatchString

func splitID(id string) (local string, domain ServerName, err error) {
// Split on the first ":" character since the domain can contain ":"
// characters.
if len(id) == 0 || id[0] != sigil {
return "", "", fmt.Errorf("gomatrixserverlib: invalid ID %q doesn't start with %q", id, sigil)
}
parts := strings.SplitN(id, ":", 2)
if len(parts) != 2 {
// The ID must have a ":" character.
Expand All @@ -1127,6 +1133,60 @@ func SplitID(sigil byte, id string) (local string, domain ServerName, err error)
return parts[0][1:], ServerName(parts[1]), nil
}

// Deprecated: Replaced with SplitIDWithSigil
// SplitID splits a matrix ID into a local part and a server name.
func SplitID(unusedSigil byte, id string) (local string, domain ServerName, err error) {
if len(id) == 0 {
return "", "", fmt.Errorf("gomatrixserverlib: invalid ID %q", id)
}
// IDs have the format: SIGIL LOCALPART ":" DOMAIN
// Split on the first ":" character since the domain can contain ":"
// characters.
sigil := id[0]
if unusedSigil != sigil {
return "", "", fmt.Errorf("gomatrixserverlib: invalid ID %q doesn't start with %q", id, unusedSigil)
}
return SplitIDWithSigil(id)
}

// SplitID splits a matrix ID into a local part and a server name.
func SplitIDWithSigil(id string) (local string, domain ServerName, err error) {
if len(id) == 0 {
return "", "", fmt.Errorf("gomatrixserverlib: invalid ID %q", id)
}
// IDs have the format: SIGIL LOCALPART ":" DOMAIN
sigil := id[0]
switch sigil {
case UserPermanentKeySigil:
{
version := id[1]
if version == '1' {
// UPK, the whole ID is the local portion
return id[2:], "", err

}
return "", "", fmt.Errorf("gomatrixserverlib: invalid UPK version %q", version)

}
case UserIDSigil:
{
local, domain, err = splitID(id)
if err == nil {
if IsValidMXID(local) {
return local, domain, nil
}
// The Local portion of the User must be only a valid characters.
return "", "", fmt.Errorf("gomatrixserverlib: invalid local ID %q", local)
}
return
}
case RoomAliasSigil, RoomIDSigil, EventIDSigil, GroupIDSigil:
return splitID(id)
default:
return "", "", fmt.Errorf("gomatrixserverlib: invalid sigil %q", sigil)
}
}

// fixNilSlices corrects cases where nil slices end up with "null" in the
// marshalled JSON because Go stupidly doesn't care about the type in this
// situation.
Expand Down
110 changes: 110 additions & 0 deletions event_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,14 @@
package gomatrixserverlib

import (
"encoding/base64"
"encoding/json"
"errors"
"reflect"
"testing"

"github.com/stretchr/testify/assert"
"golang.org/x/crypto/ed25519"
)

func benchmarkParse(b *testing.B, eventJSON string) {
Expand Down Expand Up @@ -180,3 +184,109 @@ func TestHeaderedEventToNewEventFromUntrustedJSON(t *testing.T) {
t.Fatal("expected an UnexpectedHeaderedEvent error but got:", err)
}
}

func TestSplitID(t *testing.T) {
t.Run("To short id",
func(t *testing.T) {
_, _, err := SplitID('@', "")
assert.EqualErrorf(t, err, "gomatrixserverlib: invalid ID \"\"", "To short id")
})
t.Run("Mismatch Sigil",
func(t *testing.T) {
_, _, err := SplitID('@', "#1234abcd:test")
assert.EqualErrorf(t, err, "gomatrixserverlib: invalid ID \"#1234abcd:test\" doesn't start with '@'", "Mismatch Sigil incorrect error")
})
}
func TestSplitIDWithSigil(t *testing.T) {
t.Run("Too short id",
func(t *testing.T) {
_, _, err := SplitIDWithSigil("")
assert.EqualErrorf(t, err, "gomatrixserverlib: invalid ID \"\"", "Too short id")
})
t.Run("Invalid Sigil",
func(t *testing.T) {
_, _, err := SplitIDWithSigil("%1234abcd:test")
assert.EqualErrorf(t, err, "gomatrixserverlib: invalid sigil '%'", "Invalid Sigil incorrect error")
})

t.Run("No ServerName",
func(t *testing.T) {
_, _, err := SplitIDWithSigil("@1234abcd_test")
assert.EqualErrorf(t, err, "gomatrixserverlib: invalid ID \"@1234abcd_test\" missing ':'", "No ServerName incorrect error")
})

t.Run("UserID",
func(t *testing.T) {
localpart, domain, err := SplitIDWithSigil("@1234abcd:test")
if err != nil {
t.Fatal(err)
}
assert.Equal(t, "1234abcd", localpart, "The localpart should parse")
assert.Equal(t, ServerName("test"), domain, "The domain should parse")
})
t.Run("UserID - Missing :",
func(t *testing.T) {
_, _, err := SplitIDWithSigil("@1234Abcdtest")
assert.EqualErrorf(t, err, "gomatrixserverlib: invalid ID \"@1234Abcdtest\" missing ':'", "No : in UserID")

})
t.Run("UserID - Invalid",
func(t *testing.T) {
_, _, err := SplitIDWithSigil("@1234Abcd:test")
assert.EqualErrorf(t, err, "gomatrixserverlib: invalid local ID \"1234Abcd\"", "Error should be: %v, got: %v", "gomatrixserverlib: invalid local ID \"1234Abcd\"", err)

})

t.Run("UserID - UPK",
func(t *testing.T) {
pubKey, _, err := ed25519.GenerateKey(nil)
if err != nil {
t.Fatal(err)
}
encodedKey := base64.URLEncoding.EncodeToString(pubKey)
localpart, domain, err := SplitIDWithSigil("~1" + encodedKey)
if err != nil {
t.Fatal(err)
}
assert.Equal(t, encodedKey, localpart, "The localpart should parse")
assert.Equal(t, ServerName(""), domain, "The domain should parse")
})

t.Run("UserID - Unsupported UPK version",
func(t *testing.T) {
pubKey, _, err := ed25519.GenerateKey(nil)
if err != nil {
t.Fatal(err)
}
encodedKey := base64.URLEncoding.EncodeToString(pubKey)
_, _, err = SplitIDWithSigil("~2" + encodedKey)
assert.EqualErrorf(t, err, "gomatrixserverlib: invalid UPK version '2'", "Only version 1 supported at this time")
})

t.Run("GroupID",
func(t *testing.T) {
localpart, domain, err := SplitIDWithSigil("+group/=_-.123:my.domain")
if err != nil {
t.Fatal(err)
}
assert.Equal(t, "group/=_-.123", localpart, "The localpart should parse")
assert.Equal(t, ServerName("my.domain"), domain, "The domain should parse")
})
t.Run("GroupID - Missing :",
func(t *testing.T) {
_, _, err := SplitIDWithSigil("+group/=_-.123my.domain")
assert.EqualErrorf(t, err, "gomatrixserverlib: invalid ID \"+group/=_-.123my.domain\" missing ':'", "No : in UserID")

})

t.Run("RoomAlias",

func(t *testing.T) {
localpart, domain, err := SplitIDWithSigil("#channel:test")
if err != nil {
t.Fatal(err)
}
assert.Equal(t, "channel", localpart, "The localpart should parse")
assert.Equal(t, ServerName("test"), domain, "The domain should parse")
})
}
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ require (
github.com/matrix-org/util v0.0.0-20190711121626-527ce5ddefc7
github.com/miekg/dns v1.1.25
github.com/sirupsen/logrus v1.4.2
github.com/stretchr/testify v1.4.0 // indirect
github.com/stretchr/testify v1.4.0
github.com/tidwall/gjson v1.12.1
github.com/tidwall/sjson v1.0.3
golang.org/x/crypto v0.0.0-20200221231518-2aa609cf4a9d
Expand Down