Skip to content

Commit a592d8f

Browse files
committed
tapdb: add includeLeased parameter
1 parent 23f0d5b commit a592d8f

File tree

3 files changed

+72
-21
lines changed

3 files changed

+72
-21
lines changed

rpcserver.go

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -922,7 +922,7 @@ func (r *rpcServer) checkBalanceOverflow(ctx context.Context,
922922
case assetID != nil:
923923
// Retrieve the current asset balance.
924924
balances, err := r.cfg.AssetStore.QueryBalancesByAsset(
925-
ctx, assetID,
925+
ctx, assetID, true,
926926
)
927927
if err != nil {
928928
return fmt.Errorf("unable to query asset balance: %w",
@@ -938,7 +938,7 @@ func (r *rpcServer) checkBalanceOverflow(ctx context.Context,
938938
case groupPubKey != nil:
939939
// Retrieve the current balance of the group.
940940
balances, err := r.cfg.AssetStore.QueryAssetBalancesByGroup(
941-
ctx, groupPubKey,
941+
ctx, groupPubKey, true,
942942
)
943943
if err != nil {
944944
return fmt.Errorf("unable to query group balance: %w",
@@ -1106,9 +1106,12 @@ func (r *rpcServer) MarshalChainAsset(ctx context.Context, a *asset.ChainAsset,
11061106
}
11071107

11081108
func (r *rpcServer) listBalancesByAsset(ctx context.Context,
1109-
assetID *asset.ID) (*taprpc.ListBalancesResponse, error) {
1109+
assetID *asset.ID, includeLeased bool) (*taprpc.ListBalancesResponse,
1110+
error) {
11101111

1111-
balances, err := r.cfg.AssetStore.QueryBalancesByAsset(ctx, assetID)
1112+
balances, err := r.cfg.AssetStore.QueryBalancesByAsset(
1113+
ctx, assetID, includeLeased,
1114+
)
11121115
if err != nil {
11131116
return nil, fmt.Errorf("unable to list balances: %w", err)
11141117
}
@@ -1138,10 +1141,11 @@ func (r *rpcServer) listBalancesByAsset(ctx context.Context,
11381141
}
11391142

11401143
func (r *rpcServer) listBalancesByGroupKey(ctx context.Context,
1141-
groupKey *btcec.PublicKey) (*taprpc.ListBalancesResponse, error) {
1144+
groupKey *btcec.PublicKey,
1145+
includeLeased bool) (*taprpc.ListBalancesResponse, error) {
11421146

11431147
balances, err := r.cfg.AssetStore.QueryAssetBalancesByGroup(
1144-
ctx, groupKey,
1148+
ctx, groupKey, includeLeased,
11451149
)
11461150
if err != nil {
11471151
return nil, fmt.Errorf("unable to list balances: %w", err)
@@ -1293,7 +1297,7 @@ func (r *rpcServer) ListBalances(ctx context.Context,
12931297
copy(assetID[:], req.AssetFilter)
12941298
}
12951299

1296-
return r.listBalancesByAsset(ctx, assetID)
1300+
return r.listBalancesByAsset(ctx, assetID, req.IncludeLeased)
12971301

12981302
case *taprpc.ListBalancesRequest_GroupKey:
12991303
if !groupBy.GroupKey {
@@ -1310,7 +1314,9 @@ func (r *rpcServer) ListBalances(ctx context.Context,
13101314
}
13111315
}
13121316

1313-
return r.listBalancesByGroupKey(ctx, groupKey)
1317+
return r.listBalancesByGroupKey(
1318+
ctx, groupKey, req.IncludeLeased,
1319+
)
13141320

13151321
default:
13161322
return nil, fmt.Errorf("invalid group_by")

tapdb/asset_minting_test.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1131,7 +1131,7 @@ func TestCommitBatchChainActions(t *testing.T) {
11311131

11321132
// We'll now query for the set of balances to ensure they all line up
11331133
// with the assets we just created, including the group genesis asset.
1134-
assetBalances, err := confAssets.QueryBalancesByAsset(ctx, nil)
1134+
assetBalances, err := confAssets.QueryBalancesByAsset(ctx, nil, false)
11351135
require.NoError(t, err)
11361136
require.Equal(t, numSeedlings+1, len(assetBalances))
11371137

@@ -1153,7 +1153,7 @@ func TestCommitBatchChainActions(t *testing.T) {
11531153
}
11541154
numKeyGroups := fn.Reduce(mintedAssets, keyGroupSumReducer)
11551155
assetBalancesByGroup, err := confAssets.QueryAssetBalancesByGroup(
1156-
ctx, nil,
1156+
ctx, nil, false,
11571157
)
11581158
require.NoError(t, err)
11591159
require.Equal(t, numKeyGroups, len(assetBalancesByGroup))

tapdb/assets_store.go

Lines changed: 56 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,16 @@ type (
7676
// or for things like coin selection.
7777
QueryAssetFilters = sqlc.QueryAssetsParams
7878

79+
// QueryAssetBalancesByGroupFilters lets us query the asset balances for
80+
// asset groups or alternatively for a selected one that matches the
81+
// passed filter.
82+
QueryAssetBalancesByGroupFilters = sqlc.QueryAssetBalancesByGroupParams
83+
84+
// QueryAssetBalancesByAssetFilters lets us query the asset balances for
85+
// assets or alternatively for a selected one that matches the passed
86+
// filter.
87+
QueryAssetBalancesByAssetFilters = sqlc.QueryAssetBalancesByAssetParams
88+
7989
// UtxoQuery lets us query a managed UTXO by either the transaction it
8090
// references, or the outpoint.
8191
UtxoQuery = sqlc.FetchManagedUTXOParams
@@ -184,14 +194,15 @@ type ActiveAssetsStore interface {
184194
// QueryAssetBalancesByAsset queries the balances for assets or
185195
// alternatively for a selected one that matches the passed asset ID
186196
// filter.
187-
QueryAssetBalancesByAsset(context.Context, []byte) ([]RawAssetBalance,
188-
error)
197+
QueryAssetBalancesByAsset(context.Context,
198+
QueryAssetBalancesByAssetFilters) ([]RawAssetBalance, error)
189199

190200
// QueryAssetBalancesByGroup queries the asset balances for asset
191201
// groups or alternatively for a selected one that matches the passed
192202
// filter.
193203
QueryAssetBalancesByGroup(context.Context,
194-
[]byte) ([]RawAssetGroupBalance, error)
204+
QueryAssetBalancesByGroupFilters) ([]RawAssetGroupBalance,
205+
error)
195206

196207
// FetchGroupedAssets fetches all assets with non-nil group keys.
197208
FetchGroupedAssets(context.Context) ([]RawGroupedAsset, error)
@@ -960,18 +971,35 @@ type AssetQueryFilters struct {
960971
// QueryBalancesByAsset queries the balances for assets or alternatively
961972
// for a selected one that matches the passed asset ID filter.
962973
func (a *AssetStore) QueryBalancesByAsset(ctx context.Context,
963-
assetID *asset.ID) (map[asset.ID]AssetBalance, error) {
974+
assetID *asset.ID,
975+
includeLeased bool) (map[asset.ID]AssetBalance, error) {
976+
977+
// We'll now map the application level filtering to the type of
978+
// filtering our database query understands.
979+
assetBalancesFilter := QueryAssetBalancesByAssetFilters{
980+
Now: sql.NullTime{
981+
Time: a.clock.Now().UTC(),
982+
Valid: true,
983+
},
984+
}
964985

965-
var assetFilter []byte
986+
// By default, we only show assets that are not leased.
987+
if !includeLeased {
988+
assetBalancesFilter.Leased = sqlBool(false)
989+
}
990+
991+
// Only show assets that match the filter that has been passed
966992
if assetID != nil {
967-
assetFilter = assetID[:]
993+
assetBalancesFilter.AssetIDFilter = assetID[:]
968994
}
969995

970996
balances := make(map[asset.ID]AssetBalance)
971997

972998
readOpts := NewAssetStoreReadTx()
973999
dbErr := a.db.ExecTx(ctx, &readOpts, func(q ActiveAssetsStore) error {
974-
dbBalances, err := q.QueryAssetBalancesByAsset(ctx, assetFilter)
1000+
dbBalances, err := q.QueryAssetBalancesByAsset(
1001+
ctx, assetBalancesFilter,
1002+
)
9751003
if err != nil {
9761004
return fmt.Errorf("unable to query asset "+
9771005
"balances by asset: %w", err)
@@ -1014,20 +1042,37 @@ func (a *AssetStore) QueryBalancesByAsset(ctx context.Context,
10141042
// QueryAssetBalancesByGroup queries the asset balances for asset groups or
10151043
// alternatively for a selected one that matches the passed filter.
10161044
func (a *AssetStore) QueryAssetBalancesByGroup(ctx context.Context,
1017-
groupKey *btcec.PublicKey) (map[asset.SerializedKey]AssetGroupBalance,
1045+
groupKey *btcec.PublicKey,
1046+
includeLeased bool) (map[asset.SerializedKey]AssetGroupBalance,
10181047
error) {
10191048

1020-
var groupFilter []byte
1049+
// We'll now map the application level filtering to the type of
1050+
// filtering our database query understands.
1051+
assetBalancesFilter := QueryAssetBalancesByGroupFilters{
1052+
Now: sql.NullTime{
1053+
Time: a.clock.Now().UTC(),
1054+
Valid: true,
1055+
},
1056+
}
1057+
1058+
// By default, we only show assets that are not leased.
1059+
if !includeLeased {
1060+
assetBalancesFilter.Leased = sqlBool(false)
1061+
}
1062+
1063+
// Only show specific group if a groupKey has been passed.
10211064
if groupKey != nil {
10221065
groupKeySerialized := groupKey.SerializeCompressed()
1023-
groupFilter = groupKeySerialized[:]
1066+
assetBalancesFilter.KeyGroupFilter = groupKeySerialized[:]
10241067
}
10251068

10261069
balances := make(map[asset.SerializedKey]AssetGroupBalance)
10271070

10281071
readOpts := NewAssetStoreReadTx()
10291072
dbErr := a.db.ExecTx(ctx, &readOpts, func(q ActiveAssetsStore) error {
1030-
dbBalances, err := q.QueryAssetBalancesByGroup(ctx, groupFilter)
1073+
dbBalances, err := q.QueryAssetBalancesByGroup(
1074+
ctx, assetBalancesFilter,
1075+
)
10311076
if err != nil {
10321077
return fmt.Errorf("unable to query asset "+
10331078
"balances by asset: %w", err)

0 commit comments

Comments
 (0)