Skip to content

Commit 9316a59

Browse files
committed
firewalldb+rules: thread contexts through KVStores methods
1 parent 64ab73a commit 9316a59

File tree

4 files changed

+88
-36
lines changed

4 files changed

+88
-36
lines changed

firewalldb/kvstores.go

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -61,13 +61,15 @@ type KVStores interface {
6161
// error, the transaction is rolled back. If the rollback fails, the
6262
// original error returned by f is still returned. If the commit fails,
6363
// the commit error is returned.
64-
Update(f func(tx KVStoreTx) error) error
64+
Update(ctx context.Context, f func(ctx context.Context,
65+
tx KVStoreTx) error) error
6566

6667
// View opens a database read transaction and executes the function f
6768
// with the transaction passed as a parameter. After f exits, the
6869
// transaction is rolled back. If f errors, its error is returned, not a
6970
// rollback error (if any occur).
70-
View(f func(tx KVStoreTx) error) error
71+
View(ctx context.Context, f func(ctx context.Context,
72+
tx KVStoreTx) error) error
7173
}
7274

7375
// KVStoreTx represents a database transaction that can be used for both read
@@ -158,7 +160,9 @@ func (s *kvStores) beginTx(writable bool) (*kvStoreTx, error) {
158160
// returned.
159161
//
160162
// NOTE: this is part of the KVStores interface.
161-
func (s *kvStores) Update(f func(tx KVStoreTx) error) error {
163+
func (s *kvStores) Update(ctx context.Context, f func(ctx context.Context,
164+
tx KVStoreTx) error) error {
165+
162166
tx, err := s.beginTx(true)
163167
if err != nil {
164168
return err
@@ -171,7 +175,7 @@ func (s *kvStores) Update(f func(tx KVStoreTx) error) error {
171175
}
172176
}()
173177

174-
err = f(tx)
178+
err = f(ctx, tx)
175179
if err != nil {
176180
// Want to return the original error, not a rollback error if
177181
// any occur.
@@ -188,7 +192,9 @@ func (s *kvStores) Update(f func(tx KVStoreTx) error) error {
188192
// occur).
189193
//
190194
// NOTE: this is part of the KVStores interface.
191-
func (s *kvStores) View(f func(tx KVStoreTx) error) error {
195+
func (s *kvStores) View(ctx context.Context, f func(ctx context.Context,
196+
tx KVStoreTx) error) error {
197+
192198
tx, err := s.beginTx(false)
193199
if err != nil {
194200
return err
@@ -201,7 +207,7 @@ func (s *kvStores) View(f func(tx KVStoreTx) error) error {
201207
}
202208
}()
203209

204-
err = f(tx)
210+
err = f(ctx, tx)
205211
rollbackErr := tx.boltTx.Rollback()
206212
if err != nil {
207213
return err

firewalldb/kvstores_test.go

Lines changed: 59 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ func TestKVStoreTxs(t *testing.T) {
2828

2929
// Test that if an action fails midway through the transaction, then
3030
// it is rolled back.
31-
err = store.Update(func(tx KVStoreTx) error {
31+
err = store.Update(ctx, func(ctx context.Context, tx KVStoreTx) error {
3232
err := tx.Global().Set(ctx, "test", []byte{1})
3333
if err != nil {
3434
return err
@@ -46,7 +46,7 @@ func TestKVStoreTxs(t *testing.T) {
4646
require.Error(t, err)
4747

4848
var v []byte
49-
err = store.View(func(tx KVStoreTx) error {
49+
err = store.View(ctx, func(ctx context.Context, tx KVStoreTx) error {
5050
b, err := tx.Global().Get(ctx, "test")
5151
if err != nil {
5252
return err
@@ -94,7 +94,7 @@ func testTempAndPermStores(t *testing.T, featureSpecificStore bool) {
9494

9595
store := db.GetKVStores("test-rule", [4]byte{1, 1, 1, 1}, featureName)
9696

97-
err = store.Update(func(tx KVStoreTx) error {
97+
err = store.Update(ctx, func(ctx context.Context, tx KVStoreTx) error {
9898
// Set an item in the temp store.
9999
err := tx.LocalTemp().Set(ctx, "test", []byte{4, 3, 2})
100100
if err != nil {
@@ -112,7 +112,7 @@ func testTempAndPermStores(t *testing.T, featureSpecificStore bool) {
112112
v1 []byte
113113
v2 []byte
114114
)
115-
err = store.View(func(tx KVStoreTx) error {
115+
err = store.View(ctx, func(ctx context.Context, tx KVStoreTx) error {
116116
b, err := tx.LocalTemp().Get(ctx, "test")
117117
if err != nil {
118118
return err
@@ -144,7 +144,7 @@ func testTempAndPermStores(t *testing.T, featureSpecificStore bool) {
144144

145145
// The temp store should no longer have the stored value but the perm
146146
// store should .
147-
err = store.View(func(tx KVStoreTx) error {
147+
err = store.View(ctx, func(ctx context.Context, tx KVStoreTx) error {
148148
b, err := tx.LocalTemp().Get(ctx, "test")
149149
if err != nil {
150150
return err
@@ -188,29 +188,37 @@ func TestKVStoreNameSpaces(t *testing.T) {
188188
rulesDB3 := db.GetKVStores("test-rule", groupID2, "re-balance")
189189

190190
// Test that the three ruleDBs share the same global space.
191-
err = rulesDB1.Update(func(tx KVStoreTx) error {
191+
err = rulesDB1.Update(ctx, func(ctx context.Context,
192+
tx KVStoreTx) error {
193+
192194
return tx.Global().Set(
193195
ctx, "test-global", []byte("global thing!"),
194196
)
195197
})
196198
require.NoError(t, err)
197199

198-
err = rulesDB2.Update(func(tx KVStoreTx) error {
200+
err = rulesDB2.Update(ctx, func(ctx context.Context,
201+
tx KVStoreTx) error {
202+
199203
return tx.Global().Set(
200204
ctx, "test-global", []byte("different global thing!"),
201205
)
202206
})
203207
require.NoError(t, err)
204208

205-
err = rulesDB3.Update(func(tx KVStoreTx) error {
209+
err = rulesDB3.Update(ctx, func(ctx context.Context,
210+
tx KVStoreTx) error {
211+
206212
return tx.Global().Set(
207213
ctx, "test-global", []byte("yet another global thing"),
208214
)
209215
})
210216
require.NoError(t, err)
211217

212218
var v []byte
213-
err = rulesDB1.View(func(tx KVStoreTx) error {
219+
err = rulesDB1.View(ctx, func(ctx context.Context,
220+
tx KVStoreTx) error {
221+
214222
b, err := tx.Global().Get(ctx, "test-global")
215223
if err != nil {
216224
return err
@@ -221,7 +229,9 @@ func TestKVStoreNameSpaces(t *testing.T) {
221229
require.NoError(t, err)
222230
require.True(t, bytes.Equal(v, []byte("yet another global thing")))
223231

224-
err = rulesDB2.View(func(tx KVStoreTx) error {
232+
err = rulesDB2.View(ctx, func(ctx context.Context,
233+
tx KVStoreTx) error {
234+
225235
b, err := tx.Global().Get(ctx, "test-global")
226236
if err != nil {
227237
return err
@@ -232,7 +242,9 @@ func TestKVStoreNameSpaces(t *testing.T) {
232242
require.NoError(t, err)
233243
require.True(t, bytes.Equal(v, []byte("yet another global thing")))
234244

235-
err = rulesDB3.View(func(tx KVStoreTx) error {
245+
err = rulesDB3.View(ctx, func(ctx context.Context,
246+
tx KVStoreTx) error {
247+
236248
b, err := tx.Global().Get(ctx, "test-global")
237249
if err != nil {
238250
return err
@@ -244,22 +256,30 @@ func TestKVStoreNameSpaces(t *testing.T) {
244256
require.True(t, bytes.Equal(v, []byte("yet another global thing")))
245257

246258
// Test that the feature space is not shared by any of the dbs.
247-
err = rulesDB1.Update(func(tx KVStoreTx) error {
259+
err = rulesDB1.Update(ctx, func(ctx context.Context,
260+
tx KVStoreTx) error {
261+
248262
return tx.Local().Set(ctx, "count", []byte("1"))
249263
})
250264
require.NoError(t, err)
251265

252-
err = rulesDB2.Update(func(tx KVStoreTx) error {
266+
err = rulesDB2.Update(ctx, func(ctx context.Context,
267+
tx KVStoreTx) error {
268+
253269
return tx.Local().Set(ctx, "count", []byte("2"))
254270
})
255271
require.NoError(t, err)
256272

257-
err = rulesDB3.Update(func(tx KVStoreTx) error {
273+
err = rulesDB3.Update(ctx, func(ctx context.Context,
274+
tx KVStoreTx) error {
275+
258276
return tx.Local().Set(ctx, "count", []byte("3"))
259277
})
260278
require.NoError(t, err)
261279

262-
err = rulesDB1.View(func(tx KVStoreTx) error {
280+
err = rulesDB1.View(ctx, func(ctx context.Context,
281+
tx KVStoreTx) error {
282+
263283
b, err := tx.Local().Get(ctx, "count")
264284
if err != nil {
265285
return err
@@ -270,7 +290,9 @@ func TestKVStoreNameSpaces(t *testing.T) {
270290
require.NoError(t, err)
271291
require.True(t, bytes.Equal(v, []byte("1")))
272292

273-
err = rulesDB2.View(func(tx KVStoreTx) error {
293+
err = rulesDB2.View(ctx, func(ctx context.Context,
294+
tx KVStoreTx) error {
295+
274296
b, err := tx.Local().Get(ctx, "count")
275297
if err != nil {
276298
return err
@@ -281,7 +303,9 @@ func TestKVStoreNameSpaces(t *testing.T) {
281303
require.NoError(t, err)
282304
require.True(t, bytes.Equal(v, []byte("2")))
283305

284-
err = rulesDB3.View(func(tx KVStoreTx) error {
306+
err = rulesDB3.View(ctx, func(ctx context.Context,
307+
tx KVStoreTx) error {
308+
285309
b, err := tx.Local().Get(ctx, "count")
286310
if err != nil {
287311
return err
@@ -299,22 +323,30 @@ func TestKVStoreNameSpaces(t *testing.T) {
299323
rulesDB2 = db.GetKVStores("test-rule", groupID1, "")
300324
rulesDB3 = db.GetKVStores("test-rule", groupID2, "")
301325

302-
err = rulesDB1.Update(func(tx KVStoreTx) error {
326+
err = rulesDB1.Update(ctx, func(ctx context.Context,
327+
tx KVStoreTx) error {
328+
303329
return tx.Local().Set(ctx, "test", []byte("thing 1"))
304330
})
305331
require.NoError(t, err)
306332

307-
err = rulesDB2.Update(func(tx KVStoreTx) error {
333+
err = rulesDB2.Update(ctx, func(ctx context.Context,
334+
tx KVStoreTx) error {
335+
308336
return tx.Local().Set(ctx, "test", []byte("thing 2"))
309337
})
310338
require.NoError(t, err)
311339

312-
err = rulesDB3.Update(func(tx KVStoreTx) error {
340+
err = rulesDB3.Update(ctx, func(ctx context.Context,
341+
tx KVStoreTx) error {
342+
313343
return tx.Local().Set(ctx, "test", []byte("thing 3"))
314344
})
315345
require.NoError(t, err)
316346

317-
err = rulesDB1.View(func(tx KVStoreTx) error {
347+
err = rulesDB1.View(ctx, func(ctx context.Context,
348+
tx KVStoreTx) error {
349+
318350
b, err := tx.Local().Get(ctx, "test")
319351
if err != nil {
320352
return err
@@ -325,7 +357,9 @@ func TestKVStoreNameSpaces(t *testing.T) {
325357
require.NoError(t, err)
326358
require.True(t, bytes.Equal(v, []byte("thing 2")))
327359

328-
err = rulesDB2.View(func(tx KVStoreTx) error {
360+
err = rulesDB2.View(ctx, func(ctx context.Context,
361+
tx KVStoreTx) error {
362+
329363
b, err := tx.Local().Get(ctx, "test")
330364
if err != nil {
331365
return err
@@ -336,7 +370,9 @@ func TestKVStoreNameSpaces(t *testing.T) {
336370
require.NoError(t, err)
337371
require.True(t, bytes.Equal(v, []byte("thing 2")))
338372

339-
err = rulesDB3.View(func(tx KVStoreTx) error {
373+
err = rulesDB3.View(ctx, func(ctx context.Context,
374+
tx KVStoreTx) error {
375+
340376
b, err := tx.Local().Get(ctx, "test")
341377
if err != nil {
342378
return err

rules/mock.go

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -47,12 +47,16 @@ type mockKVStores struct {
4747
tx *mockKVStoresTX
4848
}
4949

50-
func (m *mockKVStores) Update(f func(tx firewalldb.KVStoreTx) error) error {
51-
return f(m.tx)
50+
func (m *mockKVStores) Update(ctx context.Context, f func(ctx context.Context,
51+
tx firewalldb.KVStoreTx) error) error {
52+
53+
return f(ctx, m.tx)
5254
}
5355

54-
func (m *mockKVStores) View(f func(tx firewalldb.KVStoreTx) error) error {
55-
return f(m.tx)
56+
func (m *mockKVStores) View(ctx context.Context, f func(ctx context.Context,
57+
tx firewalldb.KVStoreTx) error) error {
58+
59+
return f(ctx, m.tx)
5660
}
5761

5862
var _ firewalldb.KVStores = (*mockKVStores)(nil)

rules/onchain_budget.go

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -531,7 +531,9 @@ func (o *OnChainBudgetEnforcer) handleBatchOpenChannelRequest(
531531
func (o *OnChainBudgetEnforcer) handlePendingPayment(ctx context.Context,
532532
request *onChainAction, reqID string) error {
533533

534-
return o.GetStores().Update(func(tx firewalldb.KVStoreTx) error {
534+
return o.GetStores().Update(ctx, func(ctx context.Context,
535+
tx firewalldb.KVStoreTx) error {
536+
535537
// First, we fetch the current state of the budget.
536538
spent, pending, err := o.getBudgetState(ctx, tx)
537539
if err != nil {
@@ -586,7 +588,9 @@ type onChainAction struct {
586588
func (o *OnChainBudgetEnforcer) cancelPendingPayment(
587589
ctx context.Context) error {
588590

589-
return o.GetStores().Update(func(tx firewalldb.KVStoreTx) error {
591+
return o.GetStores().Update(ctx, func(ctx context.Context,
592+
tx firewalldb.KVStoreTx) error {
593+
590594
// First, we get our current budget state.
591595
_, pending, err := o.getBudgetState(ctx, tx)
592596
if err != nil {
@@ -643,7 +647,9 @@ func (o *OnChainBudgetEnforcer) cancelPendingPayment(
643647
func (o *OnChainBudgetEnforcer) handlePaymentConfirmed(
644648
ctx context.Context) error {
645649

646-
return o.GetStores().Update(func(tx firewalldb.KVStoreTx) error {
650+
return o.GetStores().Update(ctx, func(ctx context.Context,
651+
tx firewalldb.KVStoreTx) error {
652+
647653
// First, we get our current budget state.
648654
complete, pending, err := o.getBudgetState(ctx, tx)
649655
if err != nil {

0 commit comments

Comments
 (0)