Skip to content

Commit b3fced3

Browse files
committed
Performance and bug-fix backports from Harmony
This contains various backports: * Reduced allocations and performance improvements in state resolution * Reduced allocations and performance improvements in JSON handling * Event auth fixes, including correct error surfacing
1 parent dbd5f31 commit b3fced3

15 files changed

+180
-176
lines changed

authstate.go

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -165,7 +165,10 @@ func checkAllowedByAuthEvents(
165165
event PDU, eventsByID map[string]PDU,
166166
missingAuth EventProvider, userIDForSender spec.UserIDForSender,
167167
) error {
168-
authEvents := NewAuthEvents(nil)
168+
authEvents, err := NewAuthEvents(nil)
169+
if err != nil {
170+
return err
171+
}
169172

170173
for _, ae := range event.AuthEventIDs() {
171174
retryEvent:
@@ -214,7 +217,7 @@ func checkAllowedByAuthEvents(
214217

215218
// If we made it this far then we've successfully got as many of the auth events as
216219
// as described by AuthEventIDs(). Check if they allow the event.
217-
if err := Allowed(event, &authEvents, userIDForSender); err != nil {
220+
if err := Allowed(event, authEvents, userIDForSender); err != nil {
218221
return fmt.Errorf(
219222
"gomatrixserverlib: event with ID %q is not allowed by its auth_events: %s",
220223
event.EventID(), err.Error(),
@@ -335,7 +338,10 @@ func CheckSendJoinResponse(
335338
}
336339

337340
eventsByID := map[string]PDU{}
338-
authEventProvider := NewAuthEvents(nil)
341+
authEventProvider, err := NewAuthEvents(nil)
342+
if err != nil {
343+
return nil, err
344+
}
339345

340346
// Since checkAllowedByAuthEvents needs to be able to look up any of the
341347
// auth events by ID only, we will build a map which contains references
@@ -369,7 +375,7 @@ func CheckSendJoinResponse(
369375
}
370376

371377
// Now check that the join event is valid against the supplied state.
372-
if err := Allowed(joinEvent, &authEventProvider, userIDForSender); err != nil {
378+
if err := Allowed(joinEvent, authEventProvider, userIDForSender); err != nil {
373379
return nil, fmt.Errorf(
374380
"gomatrixserverlib: event with ID %q is not allowed by the current room state: %w",
375381
joinEvent.EventID(), err,

backfill.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,8 @@ func RequestBackfill(ctx context.Context, origin spec.ServerName, b BackfillRequ
9999
}
100100
}
101101

102-
return result, lastErr
102+
// Since we pulled in results from multiple servers we need to sort again...
103+
return ReverseTopologicalOrdering(result, TopologicalOrderByPrevEvents), lastErr
103104
}
104105

105106
/*

eventV1.go

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -215,8 +215,7 @@ func (e *eventV1) SetUnsignedField(path string, value interface{}) error {
215215
eventJSON = CanonicalJSONAssumeValid(eventJSON)
216216

217217
res := gjson.GetBytes(eventJSON, "unsigned")
218-
unsigned := RawJSONFromResult(res, eventJSON)
219-
e.eventFields.Unsigned = unsigned
218+
e.eventFields.Unsigned = []byte(res.Raw)
220219

221220
e.eventJSON = eventJSON
222221

eventauth.go

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -301,15 +301,17 @@ func (a *AuthEvents) Clear() {
301301

302302
// NewAuthEvents returns an AuthEventProvider backed by the given events. New events can be added by
303303
// calling AddEvent().
304-
func NewAuthEvents(events []PDU) AuthEvents {
304+
func NewAuthEvents(events []PDU) (*AuthEvents, error) {
305305
a := AuthEvents{
306306
events: make(map[StateKeyTuple]PDU, len(events)),
307307
roomIDs: make(map[string]struct{}),
308308
}
309309
for _, e := range events {
310-
a.AddEvent(e) // nolint: errcheck
310+
if err := a.AddEvent(e); err != nil {
311+
return nil, err
312+
}
311313
}
312-
return a
314+
return &a, nil
313315
}
314316

315317
// A NotAllowed error is returned if an event does not pass the auth checks.

eventauth_test.go

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ func TestStateNeededForMessage(t *testing.T) {
104104
// Message events need the create event, the sender and the power_levels.
105105
testStateNeededForAuth(t, `[{
106106
"type": "m.room.message",
107-
"sender": "@u1:a",
107+
"sender": "@u1:a",
108108
"room_id": "!r1:a"
109109
}]`, &ProtoEvent{
110110
Type: "m.room.message",
@@ -139,7 +139,7 @@ func TestStateNeededForJoin(t *testing.T) {
139139
"type": "m.room.member",
140140
"state_key": "@u1:a",
141141
"sender": "@u1:a",
142-
"content": {"membership": "join"},
142+
"content": {"membership": "join"},
143143
"room_id": "!r1:a"
144144
}]`, &b, StateNeeded{
145145
Create: true,
@@ -163,7 +163,7 @@ func TestStateNeededForInvite(t *testing.T) {
163163
"type": "m.room.member",
164164
"state_key": "@u2:b",
165165
"sender": "@u1:a",
166-
"content": {"membership": "invite"},
166+
"content": {"membership": "invite"},
167167
"room_id": "!r1:a"
168168
}]`, &b, StateNeeded{
169169
Create: true,
@@ -199,7 +199,7 @@ func TestStateNeededForInvite3PID(t *testing.T) {
199199
"token": "my_token"
200200
}
201201
}
202-
},
202+
},
203203
"room_id": "!r1:a"
204204
}]`, &b, StateNeeded{
205205
Create: true,
@@ -1035,7 +1035,7 @@ func TestAuthEvents(t *testing.T) {
10351035
if err != nil {
10361036
t.Fatalf("TestAuthEvents: failed to create power_levels event: %s", err)
10371037
}
1038-
a := NewAuthEvents([]PDU{power})
1038+
a, _ := NewAuthEvents([]PDU{power})
10391039
var e PDU
10401040
if e, err = a.PowerLevels(); err != nil || e != power {
10411041
t.Errorf("TestAuthEvents: failed to get same power_levels event")
@@ -1685,15 +1685,15 @@ func TestMembershipBanned(t *testing.T) {
16851685
"state_key": "@u3:a",
16861686
"event_id": "$e4:a",
16871687
"content": {"membership": "ban"}
1688-
},
1688+
},
16891689
{
16901690
"type": "m.room.member",
16911691
"sender": "@u2:a",
16921692
"room_id": "!r1:a",
16931693
"state_key": "@u3:a",
16941694
"event_id": "$e4:a",
16951695
"content": {"membership": "ban"}
1696-
},
1696+
},
16971697
{
16981698
"type": "m.room.member",
16991699
"sender": "@u2:a",

handleinvite_test.go

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -38,17 +38,19 @@ func (r *TestStateQuerier) GetAuthEvents(ctx context.Context, event PDU) (AuthEv
3838
return nil, fmt.Errorf("failed getting auth provider")
3939
}
4040

41-
eventProvider := AuthEvents{}
41+
eventProvider, _ := NewAuthEvents(nil)
4242
if r.createEvent != nil {
43-
eventProvider = NewAuthEvents([]PDU{r.createEvent})
43+
if err := eventProvider.AddEvent(r.createEvent); err != nil {
44+
return nil, err
45+
}
4446
if r.inviterMemberEvent != nil {
4547
err := eventProvider.AddEvent(r.inviterMemberEvent)
4648
if err != nil {
4749
return nil, err
4850
}
4951
}
5052
}
51-
return &eventProvider, nil
53+
return eventProvider, nil
5254
}
5355

5456
func (r *TestStateQuerier) GetState(ctx context.Context, roomID spec.RoomID, stateWanted []StateKeyTuple) ([]PDU, error) {

handlejoin.go

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -116,8 +116,11 @@ func HandleMakeJoin(input HandleMakeJoinInput) (*HandleMakeJoinResponse, error)
116116
return nil, spec.InternalServerError{Err: fmt.Sprintf("expected join event from template builder. got: %s", event.Type())}
117117
}
118118

119-
provider := NewAuthEvents(state)
120-
if err = Allowed(event, &provider, input.UserIDQuerier); err != nil {
119+
provider, err := NewAuthEvents(state)
120+
if err != nil {
121+
return nil, spec.Forbidden(err.Error())
122+
}
123+
if err = Allowed(event, provider, input.UserIDQuerier); err != nil {
121124
return nil, spec.Forbidden(err.Error())
122125
}
123126

handlejoin_test.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -632,6 +632,7 @@ func TestHandleMakeJoinNilContext(t *testing.T) {
632632
})
633633
}
634634

635+
//nolint:unparam
635636
func createMemberEventBuilder(roomVersion RoomVersion, sender string, roomID string, stateKey *string, content spec.RawJSON) *EventBuilder {
636637
return MustGetRoomVersion(roomVersion).NewEventBuilderFromProtoEvent(&ProtoEvent{
637638
SenderID: sender,

handleleave.go

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -81,8 +81,11 @@ func HandleMakeLeave(input HandleMakeLeaveInput) (*HandleMakeLeaveResponse, erro
8181
return nil, spec.InternalServerError{Err: fmt.Sprintf("expected leave event from template builder. got: %s", event.Type())}
8282
}
8383

84-
provider := NewAuthEvents(stateEvents)
85-
if err := Allowed(event, &provider, input.UserIDQuerier); err != nil {
84+
provider, err := NewAuthEvents(stateEvents)
85+
if err != nil {
86+
return nil, spec.Forbidden(err.Error())
87+
}
88+
if err = Allowed(event, provider, input.UserIDQuerier); err != nil {
8689
return nil, spec.Forbidden(err.Error())
8790
}
8891

json.go

Lines changed: 24 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ package gomatrixserverlib
1818
import (
1919
"encoding/binary"
2020
"errors"
21-
"sort"
21+
"slices"
2222
"strings"
2323
"unicode/utf16"
2424
"unicode/utf8"
@@ -159,40 +159,33 @@ func CanonicalJSONAssumeValid(input []byte) []byte {
159159
// by codepoint. The input must be valid JSON.
160160
func SortJSON(input, output []byte) []byte {
161161
result := gjson.ParseBytes(input)
162-
163-
RawJSON := RawJSONFromResult(result, input)
164-
return sortJSONValue(result, RawJSON, output)
162+
return sortJSONValue(result, output)
165163
}
166164

167165
// sortJSONValue takes a gjson.Result and sorts it. inputJSON must be the
168166
// raw JSON bytes that gjson.Result points to.
169-
func sortJSONValue(input gjson.Result, inputJSON, output []byte) []byte {
167+
func sortJSONValue(input gjson.Result, output []byte) []byte {
170168
if input.IsArray() {
171-
return sortJSONArray(input, inputJSON, output)
169+
return sortJSONArray(input, output)
172170
}
173-
174171
if input.IsObject() {
175-
return sortJSONObject(input, inputJSON, output)
172+
return sortJSONObject(input, output)
176173
}
177-
178174
// If its neither an object nor an array then there is no sub structure
179175
// to sort, so just append the raw bytes.
180-
return append(output, inputJSON...)
176+
return append(output, input.Raw...)
181177
}
182178

183179
// sortJSONArray takes a gjson.Result and sorts it, assuming its an array.
184180
// inputJSON must be the raw JSON bytes that gjson.Result points to.
185-
func sortJSONArray(input gjson.Result, inputJSON, output []byte) []byte {
181+
func sortJSONArray(input gjson.Result, output []byte) []byte {
186182
sep := byte('[')
187183

188184
// Iterate over each value in the array and sort it.
189185
input.ForEach(func(_, value gjson.Result) bool {
190186
output = append(output, sep)
191187
sep = ','
192-
193-
RawJSON := RawJSONFromResult(value, inputJSON)
194-
output = sortJSONValue(value, RawJSON, output)
195-
188+
output = sortJSONValue(value, output)
196189
return true // keep iterating
197190
})
198191

@@ -209,29 +202,30 @@ func sortJSONArray(input gjson.Result, inputJSON, output []byte) []byte {
209202

210203
// sortJSONObject takes a gjson.Result and sorts it, assuming its an object.
211204
// inputJSON must be the raw JSON bytes that gjson.Result points to.
212-
func sortJSONObject(input gjson.Result, inputJSON, output []byte) []byte {
205+
func sortJSONObject(input gjson.Result, output []byte) []byte {
213206
type entry struct {
214-
key string // The parsed key string
215-
rawKey []byte // The raw, unparsed key JSON string
216-
value gjson.Result
207+
key string // The parsed key string
208+
value gjson.Result
217209
}
218210

219-
var entries []entry
211+
// Try to stay on the stack here if we can.
212+
var _entries [128]entry
213+
entries := _entries[:0]
220214

221215
// Iterate over each key/value pair and add it to a slice
222216
// that we can sort
223217
input.ForEach(func(key, value gjson.Result) bool {
224218
entries = append(entries, entry{
225-
key: key.String(),
226-
rawKey: RawJSONFromResult(key, inputJSON),
227-
value: value,
219+
key: key.String(),
220+
value: value,
228221
})
229222
return true // keep iterating
230223
})
231224

232-
// Sort the slice based on the *parsed* key
233-
sort.Slice(entries, func(a, b int) bool {
234-
return entries[a].key < entries[b].key
225+
// Using slices.SortFunc here instead of sort.Slice avoids
226+
// heap escapes due to reflection.
227+
slices.SortFunc(entries, func(a, b entry) int {
228+
return strings.Compare(a.key, b.key)
235229
})
236230

237231
sep := byte('{')
@@ -241,12 +235,10 @@ func sortJSONObject(input gjson.Result, inputJSON, output []byte) []byte {
241235
sep = ','
242236

243237
// Append the raw unparsed JSON key, *not* the parsed key
244-
output = append(output, entry.rawKey...)
245-
output = append(output, ':')
246-
247-
RawJSON := RawJSONFromResult(entry.value, inputJSON)
248-
249-
output = sortJSONValue(entry.value, RawJSON, output)
238+
output = append(output, '"')
239+
output = append(output, entry.key...)
240+
output = append(output, '"', ':')
241+
output = sortJSONValue(entry.value, output)
250242
}
251243
if sep == '{' {
252244
// If sep is still '{' then the object was empty and we never wrote the
@@ -375,10 +367,3 @@ func readHexDigits(input []byte) rune {
375367
hex |= hex >> 8
376368
return rune(hex & 0xFFFF)
377369
}
378-
379-
// RawJSONFromResult extracts the raw JSON bytes pointed to by result.
380-
// input must be the json bytes that were used to generate result
381-
// TODO: Why do we do this?
382-
func RawJSONFromResult(result gjson.Result, _ []byte) []byte {
383-
return []byte(result.Raw)
384-
}

0 commit comments

Comments
 (0)