Skip to content
This repository was archived by the owner on Nov 17, 2025. It is now read-only.

Commit 0d22cf1

Browse files
authored
Merge pull request #432 from matrix-org/kegan/device-list-updates
Ensure device list updates are robust to race conditions and network failures
2 parents 150d9d6 + f564f2d commit 0d22cf1

File tree

3 files changed

+194
-75
lines changed

3 files changed

+194
-75
lines changed

state/device_data_table.go

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ func NewDeviceDataTable(db *sqlx.DB) *DeviceDataTable {
4646
func (t *DeviceDataTable) Select(userID, deviceID string, swap bool) (result *internal.DeviceData, err error) {
4747
err = sqlutil.WithTransaction(t.db, func(txn *sqlx.Tx) error {
4848
var row DeviceDataRow
49-
err = txn.Get(&row, `SELECT data FROM syncv3_device_data WHERE user_id=$1 AND device_id=$2`, userID, deviceID)
49+
err = txn.Get(&row, `SELECT data FROM syncv3_device_data WHERE user_id=$1 AND device_id=$2 FOR UPDATE`, userID, deviceID)
5050
if err != nil {
5151
if err == sql.ErrNoRows {
5252
// if there is no device data for this user, it's not an error.
@@ -70,6 +70,9 @@ func (t *DeviceDataTable) Select(userID, deviceID string, swap bool) (result *in
7070
if !swap {
7171
return nil // don't swap
7272
}
73+
// the caller will only look at sent, so make sure what is new is now in sent
74+
result.DeviceLists.Sent = result.DeviceLists.New
75+
7376
// swap over the fields
7477
writeBack := *result
7578
writeBack.DeviceLists.Sent = result.DeviceLists.New
@@ -104,7 +107,7 @@ func (t *DeviceDataTable) Upsert(dd *internal.DeviceData) (err error) {
104107
err = sqlutil.WithTransaction(t.db, func(txn *sqlx.Tx) error {
105108
// select what already exists
106109
var row DeviceDataRow
107-
err = txn.Get(&row, `SELECT data FROM syncv3_device_data WHERE user_id=$1 AND device_id=$2`, dd.UserID, dd.DeviceID)
110+
err = txn.Get(&row, `SELECT data FROM syncv3_device_data WHERE user_id=$1 AND device_id=$2 FOR UPDATE`, dd.UserID, dd.DeviceID)
108111
if err != nil && err != sql.ErrNoRows {
109112
return err
110113
}

state/device_data_table_test.go

Lines changed: 147 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -22,17 +22,20 @@ func assertDeviceData(t *testing.T, g, w internal.DeviceData) {
2222
assertVal(t, "FallbackKeyTypes", g.FallbackKeyTypes, w.FallbackKeyTypes)
2323
assertVal(t, "OTKCounts", g.OTKCounts, w.OTKCounts)
2424
assertVal(t, "ChangedBits", g.ChangedBits, w.ChangedBits)
25-
assertVal(t, "DeviceLists", g.DeviceLists, w.DeviceLists)
25+
if w.DeviceLists.Sent != nil {
26+
assertVal(t, "DeviceLists.Sent", g.DeviceLists.Sent, w.DeviceLists.Sent)
27+
}
2628
}
2729

28-
func TestDeviceDataTableSwaps(t *testing.T) {
30+
// Tests OTKCounts and FallbackKeyTypes behaviour
31+
func TestDeviceDataTableOTKCountAndFallbackKeyTypes(t *testing.T) {
2932
db, close := connectToDB(t)
3033
defer close()
3134
table := NewDeviceDataTable(db)
32-
userID := "@bob"
35+
userID := "@TestDeviceDataTableOTKCountAndFallbackKeyTypes"
3336
deviceID := "BOB"
3437

35-
// test accumulating deltas
38+
// these are individual updates from Synapse from /sync v2
3639
deltas := []internal.DeviceData{
3740
{
3841
UserID: userID,
@@ -46,9 +49,6 @@ func TestDeviceDataTableSwaps(t *testing.T) {
4649
UserID: userID,
4750
DeviceID: deviceID,
4851
FallbackKeyTypes: []string{"foobar"},
49-
DeviceLists: internal.DeviceLists{
50-
New: internal.ToDeviceListChangesMap([]string{"alice"}, nil),
51-
},
5252
},
5353
{
5454
UserID: userID,
@@ -60,85 +60,157 @@ func TestDeviceDataTableSwaps(t *testing.T) {
6060
{
6161
UserID: userID,
6262
DeviceID: deviceID,
63-
DeviceLists: internal.DeviceLists{
64-
New: internal.ToDeviceListChangesMap([]string{"💣"}, nil),
65-
},
6663
},
6764
}
65+
66+
// apply them
6867
for _, dd := range deltas {
6968
err := table.Upsert(&dd)
7069
assertNoError(t, err)
7170
}
7271

72+
// read them without swapping; the later updates should have replaced the earlier ones correctly.
73+
// Because sync v2 sends the complete OTK count and complete fallback key types
74+
// every time, we always use the latest values. Because we aren't swapping, repeated
75+
// reads produce the same result.
76+
for i := 0; i < 3; i++ {
77+
got, err := table.Select(userID, deviceID, false)
78+
mustNotError(t, err)
79+
want := internal.DeviceData{
80+
UserID: userID,
81+
DeviceID: deviceID,
82+
OTKCounts: map[string]int{
83+
"foo": 99,
84+
},
85+
FallbackKeyTypes: []string{"foobar"},
86+
}
87+
want.SetFallbackKeysChanged()
88+
want.SetOTKCountChanged()
89+
assertDeviceData(t, *got, want)
90+
}
91+
// now we swap the data. This still returns the same values, but the changed bits are no longer set
92+
// on subsequent reads.
93+
got, err := table.Select(userID, deviceID, true)
94+
mustNotError(t, err)
7395
want := internal.DeviceData{
7496
UserID: userID,
7597
DeviceID: deviceID,
7698
OTKCounts: map[string]int{
7799
"foo": 99,
78100
},
79101
FallbackKeyTypes: []string{"foobar"},
80-
DeviceLists: internal.DeviceLists{
81-
New: internal.ToDeviceListChangesMap([]string{"alice", "💣"}, nil),
82-
Sent: map[string]int{},
83-
},
84102
}
85103
want.SetFallbackKeysChanged()
86104
want.SetOTKCountChanged()
87-
// check we can read-only select
105+
assertDeviceData(t, *got, want)
106+
107+
// subsequent read
108+
got, err = table.Select(userID, deviceID, false)
109+
mustNotError(t, err)
110+
want = internal.DeviceData{
111+
UserID: userID,
112+
DeviceID: deviceID,
113+
OTKCounts: map[string]int{
114+
"foo": 99,
115+
},
116+
FallbackKeyTypes: []string{"foobar"},
117+
}
118+
assertDeviceData(t, *got, want)
119+
}
120+
121+
// Tests the DeviceLists field
122+
func TestDeviceDataTableDeviceList(t *testing.T) {
123+
db, close := connectToDB(t)
124+
defer close()
125+
table := NewDeviceDataTable(db)
126+
userID := "@TestDeviceDataTableDeviceList"
127+
deviceID := "BOB"
128+
129+
// these are individual updates from Synapse from /sync v2
130+
deltas := []internal.DeviceData{
131+
{
132+
UserID: userID,
133+
DeviceID: deviceID,
134+
DeviceLists: internal.DeviceLists{
135+
New: internal.ToDeviceListChangesMap([]string{"alice"}, nil),
136+
},
137+
},
138+
{
139+
UserID: userID,
140+
DeviceID: deviceID,
141+
DeviceLists: internal.DeviceLists{
142+
New: internal.ToDeviceListChangesMap([]string{"💣"}, nil),
143+
},
144+
},
145+
}
146+
// apply them
147+
for _, dd := range deltas {
148+
err := table.Upsert(&dd)
149+
assertNoError(t, err)
150+
}
151+
152+
// check we can read-only select. This doesn't modify any fields.
88153
for i := 0; i < 3; i++ {
89154
got, err := table.Select(userID, deviceID, false)
90155
assertNoError(t, err)
91-
assertDeviceData(t, *got, want)
156+
assertDeviceData(t, *got, internal.DeviceData{
157+
UserID: userID,
158+
DeviceID: deviceID,
159+
DeviceLists: internal.DeviceLists{
160+
Sent: internal.MapStringInt{}, // until we "swap" we don't consume the New entries
161+
},
162+
})
92163
}
93-
// now swap-er-roo, at this point we still expect the "old" data,
94-
// as it is the first time we swap
164+
// now swap-er-roo, which shifts everything from New into Sent.
95165
got, err := table.Select(userID, deviceID, true)
96166
assertNoError(t, err)
97-
assertDeviceData(t, *got, want)
98-
99-
// changed bits were reset when we swapped
100-
want2 := want
101-
want2.DeviceLists = internal.DeviceLists{
102-
Sent: internal.ToDeviceListChangesMap([]string{"alice", "💣"}, nil),
103-
New: map[string]int{},
104-
}
105-
want2.ChangedBits = 0
106-
want.ChangedBits = 0
167+
assertDeviceData(t, *got, internal.DeviceData{
168+
UserID: userID,
169+
DeviceID: deviceID,
170+
DeviceLists: internal.DeviceLists{
171+
Sent: internal.ToDeviceListChangesMap([]string{"alice", "💣"}, nil),
172+
},
173+
})
107174

108175
// this is permanent: read-only views show this too.
109-
// Since we have swapped previously, we now expect New to be empty
110-
// and Sent to be set. Swap again to clear Sent.
111-
got, err = table.Select(userID, deviceID, true)
176+
got, err = table.Select(userID, deviceID, false)
112177
assertNoError(t, err)
113-
assertDeviceData(t, *got, want2)
178+
assertDeviceData(t, *got, internal.DeviceData{
179+
UserID: userID,
180+
DeviceID: deviceID,
181+
DeviceLists: internal.DeviceLists{
182+
Sent: internal.ToDeviceListChangesMap([]string{"alice", "💣"}, nil),
183+
},
184+
})
114185

115186
// We now expect empty DeviceLists, as we swapped twice.
116-
got, err = table.Select(userID, deviceID, false)
187+
got, err = table.Select(userID, deviceID, true)
117188
assertNoError(t, err)
118-
want3 := want2
119-
want3.DeviceLists = internal.DeviceLists{
120-
Sent: map[string]int{},
121-
New: map[string]int{},
122-
}
123-
assertDeviceData(t, *got, want3)
189+
assertDeviceData(t, *got, internal.DeviceData{
190+
UserID: userID,
191+
DeviceID: deviceID,
192+
DeviceLists: internal.DeviceLists{
193+
Sent: internal.MapStringInt{},
194+
},
195+
})
124196

125197
// get back the original state
126-
//err = table.DeleteDevice(userID, deviceID)
127198
assertNoError(t, err)
128199
for _, dd := range deltas {
129200
err = table.Upsert(&dd)
130201
assertNoError(t, err)
131202
}
132-
want.SetFallbackKeysChanged()
133-
want.SetOTKCountChanged()
134-
got, err = table.Select(userID, deviceID, false)
135-
assertNoError(t, err)
136-
assertDeviceData(t, *got, want)
137-
138-
// swap once then add once so both sent and new are populated
139-
// Moves Alice and Bob to Sent
140-
_, err = table.Select(userID, deviceID, true)
203+
// Move original state to Sent by swapping
204+
got, err = table.Select(userID, deviceID, true)
141205
assertNoError(t, err)
206+
assertDeviceData(t, *got, internal.DeviceData{
207+
UserID: userID,
208+
DeviceID: deviceID,
209+
DeviceLists: internal.DeviceLists{
210+
Sent: internal.ToDeviceListChangesMap([]string{"alice", "💣"}, nil),
211+
},
212+
})
213+
// Add new entries to New before acknowledging Sent
142214
err = table.Upsert(&internal.DeviceData{
143215
UserID: userID,
144216
DeviceID: deviceID,
@@ -148,20 +220,18 @@ func TestDeviceDataTableSwaps(t *testing.T) {
148220
})
149221
assertNoError(t, err)
150222

151-
want.ChangedBits = 0
152-
153-
want4 := want
154-
want4.DeviceLists = internal.DeviceLists{
155-
Sent: internal.ToDeviceListChangesMap([]string{"alice", "💣"}, nil),
156-
New: internal.ToDeviceListChangesMap([]string{"💣"}, []string{"charlie"}),
157-
}
158-
// Without swapping, we expect Alice and Bob in Sent, and Bob and Charlie in New
223+
// Reading without swapping does not move New->Sent, so returns the previous value
159224
got, err = table.Select(userID, deviceID, false)
160225
assertNoError(t, err)
161-
assertDeviceData(t, *got, want4)
226+
assertDeviceData(t, *got, internal.DeviceData{
227+
UserID: userID,
228+
DeviceID: deviceID,
229+
DeviceLists: internal.DeviceLists{
230+
Sent: internal.ToDeviceListChangesMap([]string{"alice", "💣"}, nil),
231+
},
232+
})
162233

163-
// another append then consume
164-
// This results in dave to be added to New
234+
// Append even more items to New
165235
err = table.Upsert(&internal.DeviceData{
166236
UserID: userID,
167237
DeviceID: deviceID,
@@ -170,24 +240,28 @@ func TestDeviceDataTableSwaps(t *testing.T) {
170240
},
171241
})
172242
assertNoError(t, err)
243+
244+
// Now swap: all the combined items in New go into Sent
173245
got, err = table.Select(userID, deviceID, true)
174246
assertNoError(t, err)
175-
want5 := want4
176-
want5.DeviceLists = internal.DeviceLists{
177-
Sent: internal.ToDeviceListChangesMap([]string{"alice", "💣"}, nil),
178-
New: internal.ToDeviceListChangesMap([]string{"💣"}, []string{"charlie", "dave"}),
179-
}
180-
assertDeviceData(t, *got, want5)
247+
assertDeviceData(t, *got, internal.DeviceData{
248+
UserID: userID,
249+
DeviceID: deviceID,
250+
DeviceLists: internal.DeviceLists{
251+
Sent: internal.ToDeviceListChangesMap([]string{"💣", "dave"}, []string{"charlie", "dave"}),
252+
},
253+
})
181254

182-
// Swapping again clears New
255+
// Swapping again clears Sent out, and since nothing is in New we get an empty list
183256
got, err = table.Select(userID, deviceID, true)
184257
assertNoError(t, err)
185-
want5 = want4
186-
want5.DeviceLists = internal.DeviceLists{
187-
Sent: internal.ToDeviceListChangesMap([]string{"💣"}, []string{"charlie", "dave"}),
188-
New: map[string]int{},
189-
}
190-
assertDeviceData(t, *got, want5)
258+
assertDeviceData(t, *got, internal.DeviceData{
259+
UserID: userID,
260+
DeviceID: deviceID,
261+
DeviceLists: internal.DeviceLists{
262+
Sent: internal.MapStringInt{},
263+
},
264+
})
191265

192266
// delete everything, no data returned
193267
assertNoError(t, table.DeleteDevice(userID, deviceID))

tests-integration/extensions_test.go

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -193,6 +193,48 @@ func TestExtensionE2EE(t *testing.T) {
193193
if time.Since(start) >= (500 * time.Millisecond) {
194194
t.Fatalf("sync request did not return immediately with OTK counts")
195195
}
196+
197+
// check that if we lose a device list update and restart from nothing, we see the same update
198+
v2.queueResponse(alice, sync2.SyncResponse{
199+
DeviceLists: struct {
200+
Changed []string `json:"changed,omitempty"`
201+
Left []string `json:"left,omitempty"`
202+
}{
203+
Changed: wantChanged,
204+
Left: wantLeft,
205+
},
206+
})
207+
v2.waitUntilEmpty(t, alice)
208+
res = v3.mustDoV3RequestWithPos(t, aliceToken, res.Pos, sync3.Request{
209+
Lists: map[string]sync3.RequestList{"a": {
210+
Ranges: sync3.SliceRanges{
211+
[2]int64{0, 10}, // doesn't matter
212+
},
213+
}},
214+
// enable the E2EE extension
215+
Extensions: extensions.Request{
216+
E2EE: &extensions.E2EERequest{
217+
Core: extensions.Core{Enabled: &boolTrue},
218+
},
219+
},
220+
})
221+
m.MatchResponse(t, res, m.MatchDeviceLists(wantChanged, wantLeft))
222+
// we actually lost this update: start again and we should see it.
223+
res = v3.mustDoV3Request(t, aliceToken, sync3.Request{
224+
Lists: map[string]sync3.RequestList{"a": {
225+
Ranges: sync3.SliceRanges{
226+
[2]int64{0, 10}, // doesn't matter
227+
},
228+
}},
229+
// enable the E2EE extension
230+
Extensions: extensions.Request{
231+
E2EE: &extensions.E2EERequest{
232+
Core: extensions.Core{Enabled: &boolTrue},
233+
},
234+
},
235+
})
236+
m.MatchResponse(t, res, m.MatchDeviceLists(wantChanged, wantLeft))
237+
196238
}
197239

198240
// Checks that to-device messages are passed from v2 to v3

0 commit comments

Comments
 (0)