diff --git a/go.mod b/go.mod index b127f37f09..447c9d81e9 100644 --- a/go.mod +++ b/go.mod @@ -68,7 +68,7 @@ replace ( github.com/gogo/protobuf => github.com/regen-network/protobuf v1.3.3-alpha.regen.1 // use cometBFT system fork of tendermint with akash patches - github.com/tendermint/tendermint => github.com/akash-network/cometbft v0.34.27-akash.2 + github.com/tendermint/tendermint => github.com/akash-network/cometbft v0.34.27-akash.3 github.com/zondax/hid => github.com/troian/hid v0.13.2 github.com/zondax/ledger-go => github.com/akash-network/ledger-go v0.14.3 diff --git a/go.sum b/go.sum index 7084817548..bcdb66cc99 100644 --- a/go.sum +++ b/go.sum @@ -78,8 +78,8 @@ github.com/ajg/form v1.5.1/go.mod h1:uL1WgH+h2mgNtvBq0339dVnzXdBETtL2LeUXaIv25UY github.com/ajstarks/svgo v0.0.0-20180226025133-644b8db467af/go.mod h1:K08gAheRH3/J6wwsYMMT4xOr94bZjxIelGM0+d/wbFw= github.com/akash-network/akash-api v0.0.75 h1:h9RZemWa7JqMGYb3nVRhRgP4xZnACIy0yN7de60JLyg= github.com/akash-network/akash-api v0.0.75/go.mod h1:pvoHHEQbt63+U+HUSTjssZ1nUJ8sJuWtHCu6ztaXcqo= -github.com/akash-network/cometbft v0.34.27-akash.2 h1:2hKEcX+cIv/OLAJ82gBWdkZlVWn+8JUYs4GrDoPAOhU= -github.com/akash-network/cometbft v0.34.27-akash.2/go.mod h1:BcCbhKv7ieM0KEddnYXvQZR+pZykTKReJJYf7YC7qhw= +github.com/akash-network/cometbft v0.34.27-akash.3 h1:ObmkKrMybIuRLPcwPwUMJ8Pllsr+Gsve443mkJsonMA= +github.com/akash-network/cometbft v0.34.27-akash.3/go.mod h1:BcCbhKv7ieM0KEddnYXvQZR+pZykTKReJJYf7YC7qhw= github.com/akash-network/cosmos-sdk v0.45.16-akash.3 h1:QiHOQ1ACzCvAEXRlzGNQhp9quWLOowE104D0uESGrEk= github.com/akash-network/cosmos-sdk v0.45.16-akash.3/go.mod h1:NTnk/GuQdFyfk/iGFxDAgQH9fwcbRW/hREap6qaPg48= github.com/akash-network/ledger-go v0.14.3 h1:LCEFkTfgGA2xFMN2CtiKvXKE7dh0QSM77PJHCpSkaAo= diff --git a/tests/upgrade/upgrade_test.go b/tests/upgrade/upgrade_test.go index 2ef1c6846f..969afae3ea 100644 --- a/tests/upgrade/upgrade_test.go +++ b/tests/upgrade/upgrade_test.go @@ -254,6 +254,7 @@ type 
nodeInitParams struct { grpcPort uint16 grpcWebPort uint16 pprofPort uint16 + apiPort uint16 } var ( @@ -446,6 +447,7 @@ func TestUpgrade(t *testing.T) { grpcPort: 9090 + uint16(idx*3), grpcWebPort: 9091 + uint16(idx*3), pprofPort: 6060 + uint16(idx), + apiPort: 1317 + uint16(idx), } } @@ -488,6 +490,7 @@ func TestUpgrade(t *testing.T) { fmt.Sprintf("AKASH_RPC_PPROF_LADDR=%s:%d", listenAddr, params.pprofPort), fmt.Sprintf("AKASH_GRPC_ADDRESS=%s:%d", listenAddr, params.grpcPort), fmt.Sprintf("AKASH_GRPC_WEB_ADDRESS=%s:%d", listenAddr, params.grpcWebPort), + fmt.Sprintf("AKASH_API_ADDRESS=tcp://%s:%d", listenAddr, params.apiPort), "DAEMON_NAME=akash", "DAEMON_RESTART_AFTER_UPGRADE=true", "DAEMON_ALLOW_DOWNLOAD_BINARIES=true", @@ -508,6 +511,7 @@ func TestUpgrade(t *testing.T) { "AKASH_TX_INDEX_INDEXER=null", "AKASH_GRPC_ENABLE=true", "AKASH_GRPC_WEB_ENABLE=true", + "AKASH_API_ENABLE=true", }, } } diff --git a/util/query/pagination.go b/util/query/pagination.go new file mode 100644 index 0000000000..a0a6bf4deb --- /dev/null +++ b/util/query/pagination.go @@ -0,0 +1,143 @@ +package query + +import ( + "encoding/binary" + "fmt" + "hash/crc32" + + "github.com/akash-network/node/util/validation" +) + +var ( + ErrInvalidPaginationKey = fmt.Errorf("pagination: invalid key") +) + +// DecodePaginationKey parses the pagination key and returns the states, prefix and key to be used by the FilteredPaginate +func DecodePaginationKey(key []byte) ([]byte, []byte, []byte, []byte, error) { + if len(key) < 5 { + return nil, nil, nil, nil, fmt.Errorf("%w: invalid key length", ErrInvalidPaginationKey) + } + + expectedChecksum := binary.BigEndian.Uint32(key) + + key = key[4:] + + checksum := crc32.ChecksumIEEE(key) + + if expectedChecksum != checksum { + return nil, nil, nil, nil, fmt.Errorf("%w: invalid checksum, 0x%08x != 0x%08x", ErrInvalidPaginationKey, expectedChecksum, checksum) + } + + statesC := int(key[0]) + key = key[1:] + + if len(key) < statesC { + return nil, nil, nil, nil, 
fmt.Errorf("%w: invalid state length", ErrInvalidPaginationKey) + } + + states := make([]byte, 0, statesC) + for _, state := range key[:statesC] { + states = append(states, state) + } + + key = key[len(states):] + + if len(key) < 1 { + return nil, nil, nil, nil, fmt.Errorf("%w: invalid state length", ErrInvalidPaginationKey) + } + + prefixLength := int(key[0]) + + key = key[1:] + + if len(key) < prefixLength { + return nil, nil, nil, nil, fmt.Errorf("%w: invalid state length", ErrInvalidPaginationKey) + } + + prefix := key[:prefixLength] + + key = key[prefixLength:] + + if len(key) == 0 { + return nil, nil, nil, nil, fmt.Errorf("%w: invalid state length", ErrInvalidPaginationKey) + } + + keyLength := int(key[0]) + key = key[1:] + + if len(key) < keyLength { + return nil, nil, nil, nil, fmt.Errorf("%w: invalid state length", ErrInvalidPaginationKey) + } + + pkey := key[:keyLength] + + key = key[keyLength:] + var unsolicited []byte + + if len(key) > 0 { + keyLength = int(key[0]) + key = key[1:] + + if len(key) != keyLength { + return nil, nil, nil, nil, fmt.Errorf("%w: invalid state length", ErrInvalidPaginationKey) + } + + unsolicited = key + } + + return states, prefix, pkey, unsolicited, nil +} + +func EncodePaginationKey(states, prefix, key, unsolicited []byte) ([]byte, error) { + if len(states) == 0 { + return nil, fmt.Errorf("%w: states cannot be empty", ErrInvalidPaginationKey) + } + + if len(prefix) == 0 { + return nil, fmt.Errorf("%w: prefix cannot be empty", ErrInvalidPaginationKey) + } + + if len(key) == 0 { + return nil, fmt.Errorf("%w: key cannot be empty", ErrInvalidPaginationKey) + } + + // 4 bytes for checksum + // 1 byte for states count + // len(states) bytes for states + // 1 byte for prefix length + // len(prefix) bytes for prefix + // 1 byte for key length + // len(key) bytes for key + encLen := 4 + 1 + len(states) + 1 + len(prefix) + 1 + len(key) + + if len(unsolicited) > 0 { + encLen += 1 + len(unsolicited) + } + + buf := make([]byte, encLen) + 
+ data := buf[4:] + + tmp := validation.MustEncodeWithLengthPrefix(states) + copy(data, tmp) + + offset := len(tmp) + tmp = validation.MustEncodeWithLengthPrefix(prefix) + + copy(data[offset:], tmp) + offset += len(tmp) + + tmp = validation.MustEncodeWithLengthPrefix(key) + copy(data[offset:], tmp) + + if len(unsolicited) > 0 { + offset += len(tmp) + tmp = validation.MustEncodeWithLengthPrefix(unsolicited) + copy(data[offset:], tmp) + } + + checksum := crc32.ChecksumIEEE(data) + binary.BigEndian.PutUint32(buf, checksum) + + return buf, nil +} diff --git a/util/query/pagination_test.go b/util/query/pagination_test.go new file mode 100644 index 0000000000..672d7358c7 --- /dev/null +++ b/util/query/pagination_test.go @@ -0,0 +1,361 @@ +package query + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestEncodePaginationKey(t *testing.T) { + type testCase struct { + name string + states []byte + prefix []byte + key []byte + unsolicited []byte + wantOutput []byte + wantErr bool + } + tests := []testCase{ + { + name: "fail/all params empty", + states: nil, + prefix: nil, + key: nil, + unsolicited: nil, + wantOutput: nil, + wantErr: true, + }, + { + name: "fail/key is empty", + states: []byte{1}, + prefix: []byte{2}, + key: nil, + unsolicited: nil, + wantOutput: nil, + wantErr: true, + }, + { + name: "fail/prefix is empty", + states: []byte{1}, + prefix: nil, + key: []byte{3}, + unsolicited: nil, + wantOutput: nil, + wantErr: true, + }, + { + name: "fail/states is empty", + states: nil, + prefix: []byte{2}, + key: []byte{3}, + unsolicited: nil, + wantOutput: nil, + wantErr: true, + }, + { + name: "pass/all params valid", + states: []byte{1}, + prefix: []byte{2}, + key: []byte{3}, + unsolicited: nil, + wantOutput: []byte{0x7c, 0xd4, 0x88, 0x46, 0x1, 0x1, 0x1, 0x2, 0x1, 0x3}, + wantErr: false, + }, + { + name: "pass/all params valid with unsolicited", + states: []byte{1}, + prefix: []byte{2}, + key: []byte{3}, + unsolicited: []byte{4}, + 
wantOutput: []byte{0x1a, 0xef, 0x78, 0xe2, 0x1, 0x1, 0x1, 0x2, 0x1, 0x3, 0x1, 0x4}, + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + gotOutput, err := EncodePaginationKey(tt.states, tt.prefix, tt.key, tt.unsolicited) + if tt.wantErr { + require.Error(t, err) + require.Nil(t, gotOutput) + } else { + require.NoError(t, err) + require.NotNil(t, gotOutput) + require.Equal(t, tt.wantOutput, gotOutput) + } + }) + } +} + +func TestDecodePaginationKey(t *testing.T) { + type testCase struct { + name string + input []byte + wantStates []byte + wantPrefix []byte + wantKey []byte + wantUnsol []byte + wantErr bool + wantErrString string + } + + tests := []testCase{ + { + name: "fail/too short key", + input: []byte{0x01, 0x02, 0x03, 0x04}, + wantErr: true, + }, + { + name: "fail/invalid checksum", + input: []byte{0x01, 0x02, 0x03, 0x04, 0x01, 65}, + wantErr: true, + wantErrString: "pagination: invalid key: invalid checksum, 0x01020304 != 0x591952b8", + }, + { + name: "fail/invalid states length", + input: []byte{0xA5, 0x05, 0xDF, 0x1B, 1}, + wantErr: true, + wantErrString: "pagination: invalid key: invalid state length", + }, + { + name: "fail/invalid prefix length", + input: []byte{0x90, 0x9F, 0xB2, 0xF2, 0x01, 0x01, 0x01}, + wantErr: true, + wantErrString: "pagination: invalid key: invalid state length", + }, + { + name: "fail/invalid key length", + input: []byte{0x07, 0x0D, 0x81, 0xEB, 0x01, 0x01, 0x01, 0x02, 0x01}, + wantErr: true, + wantErrString: "pagination: invalid key: invalid state length", + }, + { + name: "fail/invalid unsolicited length", + input: []byte{0x3A, 0xC6, 0xEF, 0x36, 0x01, 0x01, 0x01, 0x02, 0x01, 0x03, 0x01}, + wantErr: true, + wantErrString: "pagination: invalid key: invalid state length", + }, + { + name: "pass/without unsolicited", + input: makeTestKey(t, []byte{1}, []byte{2}, []byte{3}, nil), + wantStates: []byte{1}, + wantPrefix: []byte{2}, + wantKey: []byte{3}, + wantUnsol: nil, + wantErr: false, + 
wantErrString: "", + }, + { + name: "pass/key with unsolicited", + input: makeTestKey(t, []byte{1}, []byte{2}, []byte{3}, []byte{4}), + wantStates: []byte{1}, + wantPrefix: []byte{2}, + wantKey: []byte{3}, + wantUnsol: []byte{4}, + wantErr: false, + wantErrString: "", + }, + { + name: "pass/multiple states", + input: makeTestKey(t, []byte{1, 7}, []byte{2}, []byte{3}, nil), + wantStates: []byte{1, 7}, + wantPrefix: []byte{2}, + wantKey: []byte{3}, + wantUnsol: nil, + wantErr: false, + wantErrString: "", + }, + { + name: "pass/key with multiple bytes", + input: makeTestKey(t, []byte{1, 7}, []byte{2, 29, 1}, []byte{3}, nil), + wantStates: []byte{1, 7}, + wantPrefix: []byte{2, 29, 1}, + wantKey: []byte{3}, + wantUnsol: nil, + wantErr: false, + wantErrString: "", + }, + { + name: "pass/unsolicited with multiple bytes", + input: makeTestKey(t, []byte{1, 7}, []byte{2, 29, 1}, []byte{3, 2, 17}, nil), + wantStates: []byte{1, 7}, + wantPrefix: []byte{2, 29, 1}, + wantKey: []byte{3, 2, 17}, + wantUnsol: nil, + wantErr: false, + wantErrString: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + gotStates, gotPrefix, gotKey, gotUnsol, err := DecodePaginationKey(tt.input) + + if tt.wantErr { + require.Error(t, err, "DecodePaginationKey() expected error but got none") + + if tt.wantErrString != "" { + require.Equal(t, tt.wantErrString, err.Error(), "DecodePaginationKey() unexpected error string") + } + + require.Nil(t, gotStates, "DecodePaginationKey() expected states to be nil") + require.Nil(t, gotPrefix, "DecodePaginationKey() expected prefix to be nil") + require.Nil(t, gotKey, "DecodePaginationKey() expected key to be nil") + require.Nil(t, gotUnsol, "DecodePaginationKey() expected unsolicited to be nil") + + return + } + + require.NoError(t, err, "DecodePaginationKey() unexpected error") + + require.Equal(t, tt.wantStates, gotStates, "DecodePaginationKey() unexpected states") + require.Equal(t, tt.wantPrefix, gotPrefix, "DecodePaginationKey() 
unexpected prefix") + require.Equal(t, tt.wantKey, gotKey, "DecodePaginationKey() unexpected key") + require.Equal(t, tt.wantUnsol, gotUnsol, "DecodePaginationKey() unexpected unsolicited") + }) + } + + // tests := []struct { + // name string + // input []byte + // wantStates []byte + // wantPrefix []byte + // wantKey []byte + // wantUnsol []byte + // wantErr bool + // wantErrString string + // }{ + // + // { + // name: "invalid states length", + // input: makeTestKey([]byte{5}, []byte{}, []byte{}, nil), + // wantErr: true, + // wantErrString: "pagination: invalid key: invalid state length", + // }, + // { + // name: "valid key without unsolicited", + // input: makeTestKey([]byte{1, 2}, []byte{3, 4}, []byte{5, 6}, nil), + // wantStates: []byte{1, 2}, + // wantPrefix: []byte{3, 4}, + // wantKey: []byte{5, 6}, + // wantUnsol: nil, + // wantErr: false, + // }, + // { + // name: "valid key with unsolicited", + // input: makeTestKey([]byte{1, 2}, []byte{3, 4}, []byte{5, 6}, []byte{7, 8}), + // wantStates: []byte{1, 2}, + // wantPrefix: []byte{3, 4}, + // wantKey: []byte{5, 6}, + // wantUnsol: []byte{7, 8}, + // wantErr: false, + // }, + // { + // name: "manually constructed valid checksum", + // input: func() []byte { + // payload := []byte{ + // 2, // states count + // 1, 2, // states + // 2, // prefix length + // 3, 4, // prefix + // 2, // key length + // 5, 6, // key + // } + // checksum := crc32.ChecksumIEEE(payload) + // result := make([]byte, 4+len(payload)) + // binary.BigEndian.PutUint32(result, checksum) + // copy(result[4:], payload) + // return result + // }(), + // wantStates: []byte{1, 2}, + // wantPrefix: []byte{3, 4}, + // wantKey: []byte{5, 6}, + // wantUnsol: nil, + // wantErr: false, + // }, + // { + // name: "corrupted first byte of checksum", + // input: func() []byte { + // payload := []byte{ + // 2, // states count + // 1, 2, // states + // 2, // prefix length + // 3, 4, // prefix + // 2, // key length + // 5, 6, // key + // } + // checksum := 
crc32.ChecksumIEEE(payload) + // result := make([]byte, 4+len(payload)) + // binary.BigEndian.PutUint32(result, checksum) + // copy(result[4:], payload) + // result[0]++ // corrupt first byte of checksum + // return result + // }(), + // wantErr: true, + // wantErrString: "pagination: invalid key: invalid checksum", + // }, + // { + // name: "corrupted last byte of checksum", + // input: func() []byte { + // payload := []byte{ + // 2, // states count + // 1, 2, // states + // 2, // prefix length + // 3, 4, // prefix + // 2, // key length + // 5, 6, // key + // } + // checksum := crc32.ChecksumIEEE(payload) + // result := make([]byte, 4+len(payload)) + // binary.BigEndian.PutUint32(result, checksum) + // copy(result[4:], payload) + // result[3]++ // corrupt last byte of checksum + // return result + // }(), + // wantErr: true, + // wantErrString: "pagination: invalid key: invalid checksum", + // }, + // { + // name: "corrupted payload with valid checksum", + // input: func() []byte { + // payload := []byte{ + // 2, // states count + // 1, 2, // states + // 2, // prefix length + // 3, 4, // prefix + // 2, // key length + // 5, 6, // key + // } + // checksum := crc32.ChecksumIEEE(payload) + // result := make([]byte, 4+len(payload)) + // binary.BigEndian.PutUint32(result, checksum) + // copy(result[4:], payload) + // result[5]++ // corrupt payload after checksum + // return result + // }(), + // wantErr: true, + // wantErrString: "pagination: invalid key: invalid checksum", + // }, + // } +} + +// makeTestKey is a helper function to create a valid pagination key for testing +func makeTestKey(t *testing.T, states, prefix, key, unsolicited []byte) []byte { + if len(states) == 0 { + t.Fatal("states cannot be empty") + } + if len(prefix) == 0 { + t.Fatal("prefix cannot be empty") + } + if len(key) == 0 { + t.Fatal("key cannot be empty") + } + + encoded, err := EncodePaginationKey(states, prefix, key, unsolicited) + if err != nil { + t.Fatalf("failed to encode pagination 
key: %v", err) + } + + return encoded +} diff --git a/util/validation/address.go b/util/validation/address.go index e11b19d1fa..671059de66 100644 --- a/util/validation/address.go +++ b/util/validation/address.go @@ -35,3 +35,20 @@ func AssertKeyLength(bz []byte, length int) { panic(err) } } + +func EncodeWithLengthPrefix(bz []byte) ([]byte, error) { + if len(bz) > 255 { + return nil, fmt.Errorf("length-prefixed address too long") + } + + return append([]byte{byte(len(bz))}, bz...), nil +} + +func MustEncodeWithLengthPrefix(bz []byte) []byte { + res, err := EncodeWithLengthPrefix(bz) + if err != nil { + panic(err) + } + + return res +} diff --git a/x/cert/keeper/grpc_query.go b/x/cert/keeper/grpc_query.go index 1e2e2d09bd..598086f111 100644 --- a/x/cert/keeper/grpc_query.go +++ b/x/cert/keeper/grpc_query.go @@ -10,6 +10,8 @@ import ( "google.golang.org/grpc/status" types "github.com/akash-network/akash-api/go/node/cert/v1beta3" + + "github.com/akash-network/node/util/query" ) // Querier is used as Keeper will have duplicate methods if used directly, and gRPC names take precedence over keeper @@ -26,12 +28,6 @@ func (q querier) Certificates(c context.Context, req *types.QueryCertificatesReq return nil, status.Error(codes.InvalidArgument, "empty request") } - stateVal := types.Certificate_State(types.Certificate_State_value[req.Filter.State]) - - if req.Filter.State != "" && stateVal == types.CertificateStateInvalid { - return nil, status.Error(codes.InvalidArgument, "invalid state value") - } - if req.Pagination == nil { req.Pagination = &sdkquery.PageRequest{} } else if req.Pagination != nil && req.Pagination.Offset > 0 && req.Filter.State == "" { @@ -42,32 +38,33 @@ func (q querier) Certificates(c context.Context, req *types.QueryCertificatesReq req.Pagination.Limit = sdkquery.DefaultLimit } - states := make([]types.Certificate_State, 0, 2) + states := make([]byte, 0, 2) + + var searchPrefix []byte // setup for case 3 - cross-index search - if req.Filter.State == 
"" { - // request has pagination key set, determine store prefix - if len(req.Pagination.Key) > 0 { - if len(req.Pagination.Key) < 3 { - return nil, status.Error(codes.InvalidArgument, "invalid pagination key") - } + // nolint: gocritic + if len(req.Pagination.Key) > 0 { + var key []byte + var err error + states, searchPrefix, key, _, err = query.DecodePaginationKey(req.Pagination.Key) + if err != nil { + return nil, status.Error(codes.Internal, err.Error()) + } - switch req.Pagination.Key[2] { - case CertStateValidPrefixID: - states = append(states, types.CertificateValid) - fallthrough - case CertStateRevokedPrefixID: - states = append(states, types.CertificateRevoked) - default: - return nil, status.Error(codes.InvalidArgument, "invalid pagination key") - } - } else { - // request does not have pagination set. Start from valid store - states = append(states, types.CertificateValid) - states = append(states, types.CertificateRevoked) + req.Pagination.Key = key + } else if req.Filter.State != "" { + stateVal := types.Certificate_State(types.Certificate_State_value[req.Filter.State]) + + if req.Filter.State != "" && stateVal == types.CertificateStateInvalid { + return nil, status.Error(codes.InvalidArgument, "invalid state value") } + + states = append(states, byte(stateVal)) } else { - states = append(states, stateVal) + // request does not have pagination set. 
Start from valid store + states = append(states, byte(types.CertificateValid)) + states = append(states, byte(types.CertificateRevoked)) } var certificates types.CertificatesResponse @@ -75,15 +72,21 @@ func (q querier) Certificates(c context.Context, req *types.QueryCertificatesReq total := uint64(0) - for _, state := range states { - var searchPrefix []byte + for idx := range states { + state := types.Certificate_State(states[idx]) var err error - req.Filter.State = state.String() + if idx > 0 { + req.Pagination.Key = nil + } - searchPrefix, err = filterToPrefix(req.Filter) - if err != nil { - return nil, status.Error(codes.Internal, err.Error()) + if len(req.Pagination.Key) == 0 { + req.Filter.State = state.String() + + searchPrefix, err = filterToPrefix(req.Filter) + if err != nil { + return nil, status.Error(codes.Internal, err.Error()) + } } searchStore := prefix.NewStore(ctx.KVStore(q.skey), searchPrefix) @@ -107,8 +110,6 @@ func (q querier) Certificates(c context.Context, req *types.QueryCertificatesReq certificates = append(certificates, item) count++ - - // return true, nil } return true, nil @@ -121,6 +122,18 @@ func (q querier) Certificates(c context.Context, req *types.QueryCertificatesReq total += count if req.Pagination.Limit == 0 { + if len(pageRes.NextKey) > 0 { + pageRes.NextKey, err = query.EncodePaginationKey(states[idx:], searchPrefix, pageRes.NextKey, nil) + if err != nil { + pageRes.Total = total + + return &types.QueryCertificatesResponse{ + Certificates: certificates, + Pagination: pageRes, + }, status.Error(codes.Internal, err.Error()) + } + } + break } } diff --git a/x/cert/keeper/grpc_query_test.go b/x/cert/keeper/grpc_query_test.go index 3e500308b1..cb1fc98d69 100644 --- a/x/cert/keeper/grpc_query_test.go +++ b/x/cert/keeper/grpc_query_test.go @@ -43,7 +43,11 @@ func setupTest(t *testing.T) *grpcTestSuite { func sortCerts(certs types.Certificates) { sort.SliceStable(certs, func(i, j int) bool { - return certs[i].State < certs[j].State + 
if certs[i].State < certs[j].State { + return true + } + + return string(certs[i].Cert) < string(certs[j].Cert) }) } @@ -54,7 +58,9 @@ func TestCertGRPCQueryCertificates(t *testing.T) { cert := testutil.Certificate(t, owner) owner2 := testutil.AccAddress(t) + owner3 := testutil.AccAddress(t) cert2 := testutil.Certificate(t, owner2) + cert3 := testutil.Certificate(t, owner3) err := suite.keeper.CreateCertificate(suite.ctx, owner, cert.PEM.Cert, cert.PEM.Pub) require.NoError(t, err) @@ -62,6 +68,9 @@ func TestCertGRPCQueryCertificates(t *testing.T) { err = suite.keeper.CreateCertificate(suite.ctx, owner2, cert2.PEM.Cert, cert2.PEM.Pub) require.NoError(t, err) + err = suite.keeper.CreateCertificate(suite.ctx, owner3, cert3.PEM.Cert, cert3.PEM.Pub) + require.NoError(t, err) + err = suite.keeper.RevokeCertificate(suite.ctx, types.CertID{ Owner: owner2, Serial: cert2.Serial, @@ -75,6 +84,7 @@ func TestCertGRPCQueryCertificates(t *testing.T) { msg string malleate func() expPass bool + nextKey bool }{ { "all certificates", @@ -86,6 +96,11 @@ func TestCertGRPCQueryCertificates(t *testing.T) { Cert: cert.PEM.Cert, Pubkey: cert.PEM.Pub, }, + types.Certificate{ + State: types.CertificateValid, + Cert: cert3.PEM.Cert, + Pubkey: cert3.PEM.Pub, + }, types.Certificate{ State: types.CertificateRevoked, Cert: cert2.PEM.Cert, @@ -94,6 +109,7 @@ func TestCertGRPCQueryCertificates(t *testing.T) { } }, true, + false, }, { "certificate not found", @@ -107,6 +123,7 @@ func TestCertGRPCQueryCertificates(t *testing.T) { expCertificates = nil }, false, + false, }, { "success valid", @@ -125,9 +142,10 @@ func TestCertGRPCQueryCertificates(t *testing.T) { } }, true, + false, }, { - "success revoked", + "success revoked by owner", func() { req = &types.QueryCertificatesRequest{ Filter: types.CertificateFilter{ @@ -143,7 +161,28 @@ func TestCertGRPCQueryCertificates(t *testing.T) { } }, true, + false, + }, + { + "success revoked by state", + func() { + req = &types.QueryCertificatesRequest{ + 
Filter: types.CertificateFilter{ + State: types.CertificateRevoked.String(), + }, + } + expCertificates = types.Certificates{ + types.Certificate{ + State: types.CertificateRevoked, + Cert: cert2.PEM.Cert, + Pubkey: cert2.PEM.Pub, + }, + } + }, + true, + false, }, + { "success pagination with limit", func() { @@ -158,6 +197,11 @@ func TestCertGRPCQueryCertificates(t *testing.T) { Cert: cert.PEM.Cert, Pubkey: cert.PEM.Pub, }, + types.Certificate{ + State: types.CertificateValid, + Cert: cert3.PEM.Cert, + Pubkey: cert3.PEM.Pub, + }, types.Certificate{ State: types.CertificateRevoked, Cert: cert2.PEM.Cert, @@ -166,6 +210,49 @@ func TestCertGRPCQueryCertificates(t *testing.T) { } }, true, + false, + }, + + // { + // "success pagination with next key", + // func() { + // req = &types.QueryCertificatesRequest{ + // Filter: types.CertificateFilter{State: types.CertificateValid.String()}, + // Pagination: &sdkquery.PageRequest{ + // Limit: 1, + // }, + // } + // expCertificates = types.Certificates{ + // types.Certificate{ + // State: types.CertificateValid, + // Cert: cert.PEM.Cert, + // Pubkey: cert.PEM.Pub, + // }, + // } + // }, + // true, + // true, + // }, + + { + "success pagination with nil key", + func() { + req = &types.QueryCertificatesRequest{ + Filter: types.CertificateFilter{State: types.CertificateRevoked.String()}, + Pagination: &sdkquery.PageRequest{ + Limit: 1, + }, + } + expCertificates = types.Certificates{ + types.Certificate{ + State: types.CertificateRevoked, + Cert: cert2.PEM.Cert, + Pubkey: cert2.PEM.Pub, + }, + } + }, + true, + false, }, { "success pagination with limit with state", @@ -184,9 +271,15 @@ func TestCertGRPCQueryCertificates(t *testing.T) { Cert: cert.PEM.Cert, Pubkey: cert.PEM.Pub, }, + types.Certificate{ + State: types.CertificateValid, + Cert: cert3.PEM.Cert, + Pubkey: cert3.PEM.Pub, + }, } }, true, + false, }, { "success pagination with limit with owner", @@ -208,6 +301,7 @@ func TestCertGRPCQueryCertificates(t *testing.T) { } }, 
true, + false, }, { "failing pagination with limit with non-existing owner", @@ -223,6 +317,7 @@ func TestCertGRPCQueryCertificates(t *testing.T) { expCertificates = nil }, false, + false, }, } @@ -239,13 +334,36 @@ func TestCertGRPCQueryCertificates(t *testing.T) { if expCertificates != nil { sortCerts(expCertificates) - respCerts := make(types.Certificates, len(res.Certificates)) - for i, cert := range res.Certificates { - respCerts[i] = cert.Certificate + respCerts := make(types.Certificates, 0, len(res.Certificates)) + for _, cert := range res.Certificates { + respCerts = append(respCerts, cert.Certificate) } sortCerts(respCerts) - require.Equal(t, expCertificates, respCerts) + + if req.Pagination != nil && req.Pagination.Limit > 0 { + require.LessOrEqual(t, len(respCerts), int(req.Pagination.Limit)) + } + + require.Len(t, respCerts, len(expCertificates)) + + for i, cert := range expCertificates { + require.Equal(t, cert, respCerts[i]) + } + } + + if tc.nextKey { + require.NotNil(t, res.Pagination.NextKey) + + req.Pagination.Key = res.Pagination.NextKey + res, err = suite.qclient.Certificates(ctx, req) + require.NoError(t, err) + require.NotNil(t, res) + if req.Pagination != nil && req.Pagination.Limit > 0 { + require.LessOrEqual(t, len(res.Certificates), int(req.Pagination.Limit)) + } + + require.Nil(t, res.Pagination.NextKey) } } else { require.NotNil(t, res) diff --git a/x/deployment/keeper/grpc_query.go b/x/deployment/keeper/grpc_query.go index becfe5b4fa..4864abb079 100644 --- a/x/deployment/keeper/grpc_query.go +++ b/x/deployment/keeper/grpc_query.go @@ -12,6 +12,8 @@ import ( sdkquery "github.com/cosmos/cosmos-sdk/types/query" types "github.com/akash-network/akash-api/go/node/deployment/v1beta3" + + "github.com/akash-network/node/util/query" ) // Querier is used as Keeper will have duplicate methods if used directly, and gRPC names take precedence over keeper @@ -34,14 +36,6 @@ func (k Querier) Deployments(c context.Context, req 
*types.QueryDeploymentsReque return nil, status.Error(codes.InvalidArgument, "empty request") } - stateVal := types.Deployment_State(types.Deployment_State_value[req.Filters.State]) - - if req.Filters.State != "" && stateVal == types.DeploymentStateInvalid { - return nil, status.Error(codes.InvalidArgument, "invalid state value") - } - - var store prefix.Store - if req.Pagination == nil { req.Pagination = &sdkquery.PageRequest{} } else if req.Pagination != nil && req.Pagination.Offset > 0 && req.Filters.State == "" { @@ -52,32 +46,33 @@ func (k Querier) Deployments(c context.Context, req *types.QueryDeploymentsReque req.Pagination.Limit = sdkquery.DefaultLimit } - states := make([]types.Deployment_State, 0, 2) + states := make([]byte, 0, 2) + + var searchPrefix []byte // setup for case 3 - cross-index search - if req.Filters.State == "" { - // request has pagination key set, determine store prefix - if len(req.Pagination.Key) > 0 { - if len(req.Pagination.Key) < 3 { - return nil, status.Error(codes.InvalidArgument, "invalid pagination key") - } + // nolint: gocritic + if len(req.Pagination.Key) > 0 { + var key []byte + var err error + states, searchPrefix, key, _, err = query.DecodePaginationKey(req.Pagination.Key) + if err != nil { + return nil, status.Error(codes.Internal, err.Error()) + } - switch req.Pagination.Key[2] { - case DeploymentStateActivePrefixID: - states = append(states, types.DeploymentActive) - fallthrough - case DeploymentStateClosedPrefixID: - states = append(states, types.DeploymentClosed) - default: - return nil, status.Error(codes.InvalidArgument, "invalid pagination key") - } - } else { - // request does not have pagination set. 
Start from active store - states = append(states, types.DeploymentActive) - states = append(states, types.DeploymentClosed) + req.Pagination.Key = key + } else if req.Filters.State != "" { + stateVal := types.Deployment_State(types.Deployment_State_value[req.Filters.State]) + + if req.Filters.State != "" && stateVal == types.DeploymentStateInvalid { + return nil, status.Error(codes.InvalidArgument, "invalid state value") } + + states = append(states, byte(stateVal)) } else { - states = append(states, stateVal) + // request does not have pagination set. Start from active store + states = append(states, byte(types.DeploymentActive)) + states = append(states, byte(types.DeploymentClosed)) } var deployments types.DeploymentResponses @@ -85,22 +80,29 @@ func (k Querier) Deployments(c context.Context, req *types.QueryDeploymentsReque total := uint64(0) - for _, state := range states { - var searchPrefix []byte + for idx := range states { + state := types.Deployment_State(states[idx]) + var err error - req.Filters.State = state.String() + if idx > 0 { + req.Pagination.Key = nil + } - searchPrefix, err = deploymentPrefixFromFilter(req.Filters) - if err != nil { - return nil, status.Error(codes.Internal, err.Error()) + if len(req.Pagination.Key) == 0 { + req.Filters.State = state.String() + + searchPrefix, err = deploymentPrefixFromFilter(req.Filters) + if err != nil { + return nil, status.Error(codes.Internal, err.Error()) + } } - store = prefix.NewStore(ctx.KVStore(k.skey), searchPrefix) + searchStore := prefix.NewStore(ctx.KVStore(k.skey), searchPrefix) count := uint64(0) - pageRes, err = sdkquery.FilteredPaginate(store, req.Pagination, func(key []byte, value []byte, accumulate bool) (bool, error) { + pageRes, err = sdkquery.FilteredPaginate(searchStore, req.Pagination, func(key []byte, value []byte, accumulate bool) (bool, error) { var deployment types.Deployment err := k.cdc.Unmarshal(value, &deployment) @@ -109,7 +111,7 @@ func (k Querier) Deployments(c 
context.Context, req *types.QueryDeploymentsReque } // filter deployments with provided filters - if req.Filters.Accept(deployment, stateVal) { + if req.Filters.Accept(deployment, state) { if accumulate { account, err := k.ekeeper.GetAccount( ctx, @@ -142,6 +144,17 @@ func (k Querier) Deployments(c context.Context, req *types.QueryDeploymentsReque total += count if req.Pagination.Limit == 0 { + if len(pageRes.NextKey) > 0 { + pageRes.NextKey, err = query.EncodePaginationKey(states[idx:], searchPrefix, pageRes.NextKey, nil) + if err != nil { + pageRes.Total = total + return &types.QueryDeploymentsResponse{ + Deployments: deployments, + Pagination: pageRes, + }, status.Error(codes.Internal, err.Error()) + } + } + break } } diff --git a/x/deployment/keeper/grpc_query_test.go b/x/deployment/keeper/grpc_query_test.go index b79e00edca..54fe311698 100644 --- a/x/deployment/keeper/grpc_query_test.go +++ b/x/deployment/keeper/grpc_query_test.go @@ -142,24 +142,44 @@ func TestGRPCQueryDeployments(t *testing.T) { suite.createEscrowAccount(deployment.ID()) deployment2, groups2 := suite.createDeployment() - deployment2.State = types.DeploymentClosed + deployment2.State = types.DeploymentActive err = suite.keeper.Create(suite.ctx, deployment2, groups2) require.NoError(t, err) suite.createEscrowAccount(deployment2.ID()) + deployment3, groups3 := suite.createDeployment() + deployment3.State = types.DeploymentClosed + err = suite.keeper.Create(suite.ctx, deployment3, groups3) + require.NoError(t, err) + suite.createEscrowAccount(deployment3.ID()) + var req *types.QueryDeploymentsRequest testCases := []struct { msg string malleate func() expLen int + nextKey bool }{ { "query deployments without any filters and pagination", func() { req = &types.QueryDeploymentsRequest{} }, + 3, + false, + }, + { + "query deployments with state filter", + func() { + req = &types.QueryDeploymentsRequest{ + Filters: types.DeploymentFilters{ + State: types.DeploymentActive.String(), + }, + } + }, 2, + 
false, }, { "query deployments with filters having non existent data", @@ -171,13 +191,16 @@ func TestGRPCQueryDeployments(t *testing.T) { }} }, 0, + false, }, { "query deployments with state filter", func() { - req = &types.QueryDeploymentsRequest{Filters: types.DeploymentFilters{State: types.DeploymentClosed.String()}} + req = &types.QueryDeploymentsRequest{ + Filters: types.DeploymentFilters{State: types.DeploymentClosed.String()}} }, 1, + false, }, { "query deployments with pagination", @@ -185,6 +208,18 @@ func TestGRPCQueryDeployments(t *testing.T) { req = &types.QueryDeploymentsRequest{Pagination: &sdkquery.PageRequest{Limit: 1}} }, 1, + false, + }, + { + "query deployments with pagination next key", + func() { + req = &types.QueryDeploymentsRequest{ + Filters: types.DeploymentFilters{State: types.DeploymentActive.String()}, + Pagination: &sdkquery.PageRequest{Limit: 1}, + } + }, + 1, + true, }, } @@ -197,7 +232,17 @@ func TestGRPCQueryDeployments(t *testing.T) { require.NoError(t, err) require.NotNil(t, res) - require.Equal(t, tc.expLen, len(res.Deployments)) + assert.Equal(t, tc.expLen, len(res.Deployments)) + + if tc.nextKey { + require.NotNil(t, res.Pagination.NextKey) + req.Pagination.Key = res.Pagination.NextKey + res, err = suite.queryClient.Deployments(ctx, req) + require.NoError(t, err) + require.NotNil(t, res) + assert.Nil(t, res.Pagination.NextKey) + assert.Equal(t, tc.expLen, len(res.Deployments)) + } }) } } diff --git a/x/market/keeper/grpc_query.go b/x/market/keeper/grpc_query.go index 0ce36f8d27..d3ab3db973 100644 --- a/x/market/keeper/grpc_query.go +++ b/x/market/keeper/grpc_query.go @@ -13,6 +13,7 @@ import ( dtypes "github.com/akash-network/akash-api/go/node/deployment/v1beta3" types "github.com/akash-network/akash-api/go/node/market/v1beta4" + "github.com/akash-network/node/util/query" keys "github.com/akash-network/node/x/market/keeper/keys/v1beta4" ) @@ -29,18 +30,6 @@ func (k Querier) Orders(c context.Context, req 
*types.QueryOrdersRequest) (*type return nil, status.Error(codes.InvalidArgument, "empty request") } - stateVal := types.Order_State(types.Order_State_value[req.Filters.State]) - - if req.Filters.State != "" && stateVal == types.OrderStateInvalid { - return nil, status.Error(codes.InvalidArgument, "invalid state value") - } - - // case 1: no filters set, iterating over entire store - // case 2: state only or state plus underlying filters like owner, iterating over state store - // case 3: state not set, underlying filters like owner are set, most complex case - - var store prefix.Store - if req.Pagination == nil { req.Pagination = &sdkquery.PageRequest{} } else if req.Pagination != nil && req.Pagination.Offset > 0 && req.Filters.State == "" { @@ -51,36 +40,37 @@ func (k Querier) Orders(c context.Context, req *types.QueryOrdersRequest) (*type req.Pagination.Limit = sdkquery.DefaultLimit } - states := make([]types.Order_State, 0, 3) + // case 1: no filters set, iterating over entire store + // case 2: state only or state plus underlying filters like owner, iterating over state store + // case 3: state not set, underlying filters like owner are set, most complex case + + states := make([]byte, 0, 3) + var searchPrefix []byte // setup for case 3 - cross-index search - if req.Filters.State == "" { - // request has pagination key set, determine store prefix - if len(req.Pagination.Key) > 0 { - if len(req.Pagination.Key) < 3 { - return nil, status.Error(codes.InvalidArgument, "invalid pagination key") - } + // nolint: gocritic + if len(req.Pagination.Key) > 0 { + var key []byte + var err error + states, searchPrefix, key, _, err = query.DecodePaginationKey(req.Pagination.Key) + if err != nil { + return nil, status.Error(codes.Internal, err.Error()) + } - switch req.Pagination.Key[2] { - case keys.OrderStateOpenPrefixID: - states = append(states, types.OrderOpen) - fallthrough - case keys.OrderStateActivePrefixID: - states = append(states, types.OrderActive) - fallthrough 
- case keys.OrderStateClosedPrefixID: - states = append(states, types.OrderClosed) - default: - return nil, status.Error(codes.InvalidArgument, "invalid pagination key") - } - } else { - // request does not have pagination set. Start from open store - states = append(states, types.OrderOpen) - states = append(states, types.OrderActive) - states = append(states, types.OrderClosed) + req.Pagination.Key = key + } else if req.Filters.State != "" { + stateVal := types.Order_State(types.Order_State_value[req.Filters.State]) + + if req.Filters.State != "" && stateVal == types.OrderStateInvalid { + return nil, status.Error(codes.InvalidArgument, "invalid state value") } + + states = append(states, byte(stateVal)) } else { - states = append(states, stateVal) + // request does not have pagination set. Start from open store + states = append(states, byte(types.OrderOpen)) + states = append(states, byte(types.OrderActive)) + states = append(states, byte(types.OrderClosed)) } var orders types.Orders @@ -90,22 +80,28 @@ func (k Querier) Orders(c context.Context, req *types.QueryOrdersRequest) (*type total := uint64(0) - for _, state := range states { - var searchPrefix []byte + for idx := range states { + state := types.Order_State(states[idx]) var err error - req.Filters.State = state.String() + if idx > 0 { + req.Pagination.Key = nil + } + + if len(req.Pagination.Key) == 0 { + req.Filters.State = state.String() - searchPrefix, err = keys.OrderPrefixFromFilter(req.Filters) - if err != nil { - return nil, status.Error(codes.Internal, err.Error()) + searchPrefix, err = keys.OrderPrefixFromFilter(req.Filters) + if err != nil { + return nil, status.Error(codes.Internal, err.Error()) + } } - store = prefix.NewStore(ctx.KVStore(k.skey), searchPrefix) + searchStore := prefix.NewStore(ctx.KVStore(k.skey), searchPrefix) count := uint64(0) - pageRes, err = sdkquery.FilteredPaginate(store, req.Pagination, func(key []byte, value []byte, accumulate bool) (bool, error) { + pageRes, err = 
sdkquery.FilteredPaginate(searchStore, req.Pagination, func(key []byte, value []byte, accumulate bool) (bool, error) { var order types.Order err := k.cdc.Unmarshal(value, &order) @@ -114,7 +110,7 @@ func (k Querier) Orders(c context.Context, req *types.QueryOrdersRequest) (*type } // filter orders with provided filters - if req.Filters.Accept(order, stateVal) { + if req.Filters.Accept(order, state) { if accumulate { orders = append(orders, order) count++ @@ -129,10 +125,29 @@ func (k Querier) Orders(c context.Context, req *types.QueryOrdersRequest) (*type return nil, status.Error(codes.Internal, err.Error()) } + if len(pageRes.NextKey) > 0 { + nextKey := make([]byte, len(searchPrefix)+len(pageRes.NextKey)) + copy(nextKey, searchPrefix) + copy(nextKey[len(searchPrefix):], pageRes.NextKey) + + pageRes.NextKey = nextKey + } + req.Pagination.Limit -= count total += count if req.Pagination.Limit == 0 { + if len(pageRes.NextKey) > 0 { + pageRes.NextKey, err = query.EncodePaginationKey(states[idx:], searchPrefix, pageRes.NextKey, nil) + if err != nil { + pageRes.Total = total + return &types.QueryOrdersResponse{ + Orders: orders, + Pagination: pageRes, + }, status.Error(codes.Internal, err.Error()) + } + } + break } } @@ -151,12 +166,6 @@ func (k Querier) Bids(c context.Context, req *types.QueryBidsRequest) (*types.Qu return nil, status.Error(codes.InvalidArgument, "empty request") } - stateVal := types.Bid_State(types.Bid_State_value[req.Filters.State]) - - if req.Filters.State != "" && stateVal == types.BidStateInvalid { - return nil, status.Error(codes.InvalidArgument, "invalid state value") - } - if req.Pagination == nil { req.Pagination = &sdkquery.PageRequest{} } else if req.Pagination != nil && req.Pagination.Offset > 0 && req.Filters.State == "" { @@ -167,42 +176,42 @@ func (k Querier) Bids(c context.Context, req *types.QueryBidsRequest) (*types.Qu req.Pagination.Limit = sdkquery.DefaultLimit } - reverseSearch := (req.Filters.Owner == "") && (req.Filters.Provider 
!= "") - states := make([]types.Bid_State, 0, 4) + reverseSearch := false + states := make([]byte, 0, 4) + var searchPrefix []byte // setup for case 3 - cross-index search - if req.Filters.State == "" { - // request has pagination key set, determine store prefix - if len(req.Pagination.Key) > 0 { - if len(req.Pagination.Key) < 3 { - return nil, status.Error(codes.InvalidArgument, "invalid pagination key") - } + // nolint: gocritic + if len(req.Pagination.Key) > 0 { + var key []byte + var unsolicited []byte + var err error + states, searchPrefix, key, unsolicited, err = query.DecodePaginationKey(req.Pagination.Key) + if err != nil { + return nil, status.Error(codes.Internal, err.Error()) + } - if reverseSearch && req.Pagination.Key[2] > keys.BidStateActivePrefixID { - return nil, status.Error(codes.InvalidArgument, "invalid pagination key") - } + if len(unsolicited) != 1 { + return nil, status.Error(codes.InvalidArgument, "invalid pagination key") + } + req.Pagination.Key = key - switch req.Pagination.Key[2] { - case keys.BidStateOpenPrefixID: - states = append(states, types.BidOpen) - fallthrough - case keys.BidStateActivePrefixID: - states = append(states, types.BidActive) - fallthrough - case keys.BidStateLostPrefixID: - states = append(states, types.BidLost) - fallthrough - case keys.BidStateClosedPrefixID: - states = append(states, types.BidClosed) - default: - return nil, status.Error(codes.InvalidArgument, "invalid pagination key") - } - } else { - // request does not have pagination set. 
Start from open store - states = append(states, types.BidOpen, types.BidActive, types.BidLost, types.BidClosed) + if unsolicited[0] == 1 { + reverseSearch = true } + } else if req.Filters.State != "" { + reverseSearch = (req.Filters.Owner == "") && (req.Filters.Provider != "") + + stateVal := types.Bid_State(types.Bid_State_value[req.Filters.State]) + + if req.Filters.State != "" && stateVal == types.BidStateInvalid { + return nil, status.Error(codes.InvalidArgument, "invalid state value") + } + + states = append(states, byte(stateVal)) } else { - states = append(states, stateVal) + // request does not have pagination set. Start from open store + states = append(states, byte(types.BidOpen), byte(types.BidActive), byte(types.BidLost), byte(types.BidClosed)) } var bids []types.QueryBidResponse @@ -211,26 +220,32 @@ func (k Querier) Bids(c context.Context, req *types.QueryBidsRequest) (*types.Qu total := uint64(0) - for _, state := range states { - var searchPrefix []byte + for idx := range states { + state := types.Bid_State(states[idx]) var err error - req.Filters.State = state.String() - - if reverseSearch { - searchPrefix, err = keys.BidReversePrefixFromFilter(req.Filters) - } else { - searchPrefix, err = keys.BidPrefixFromFilter(req.Filters) + if idx > 0 { + req.Pagination.Key = nil } - if err != nil { - return nil, status.Error(codes.Internal, err.Error()) + if len(req.Pagination.Key) == 0 { + req.Filters.State = state.String() + + if reverseSearch { + searchPrefix, err = keys.BidReversePrefixFromFilter(req.Filters) + } else { + searchPrefix, err = keys.BidPrefixFromFilter(req.Filters) + } + + if err != nil { + return nil, status.Error(codes.Internal, err.Error()) + } } count := uint64(0) + searchStore := prefix.NewStore(ctx.KVStore(k.skey), searchPrefix) - bidStore := prefix.NewStore(ctx.KVStore(k.skey), searchPrefix) - pageRes, err = sdkquery.FilteredPaginate(bidStore, req.Pagination, func(key []byte, value []byte, accumulate bool) (bool, error) { + pageRes, 
err = sdkquery.FilteredPaginate(searchStore, req.Pagination, func(key []byte, value []byte, accumulate bool) (bool, error) { var bid types.Bid err := k.cdc.Unmarshal(value, &bid) @@ -239,7 +254,7 @@ func (k Querier) Bids(c context.Context, req *types.QueryBidsRequest) (*types.Qu } // filter bids with provided filters - if req.Filters.Accept(bid, stateVal) { + if req.Filters.Accept(bid, state) { if accumulate { acct, err := k.ekeeper.GetAccount(ctx, types.EscrowAccountForBid(bid.BidID)) if err != nil { @@ -263,10 +278,35 @@ func (k Querier) Bids(c context.Context, req *types.QueryBidsRequest) (*types.Qu return nil, status.Error(codes.Internal, err.Error()) } + if len(pageRes.NextKey) > 0 { + nextKey := make([]byte, len(searchPrefix)+len(pageRes.NextKey)) + copy(nextKey, searchPrefix) + copy(nextKey[len(searchPrefix):], pageRes.NextKey) + + pageRes.NextKey = nextKey + } + req.Pagination.Limit -= count total += count if req.Pagination.Limit == 0 { + if len(pageRes.NextKey) > 0 { + unsolicited := make([]byte, 1) + unsolicited[0] = 0 + if reverseSearch { + unsolicited[0] = 1 + } + + pageRes.NextKey, err = query.EncodePaginationKey(states[idx:], searchPrefix, pageRes.NextKey, unsolicited) + if err != nil { + pageRes.Total = total + return &types.QueryBidsResponse{ + Bids: bids, + Pagination: pageRes, + }, status.Error(codes.Internal, err.Error()) + } + } + break } } @@ -285,12 +325,6 @@ func (k Querier) Leases(c context.Context, req *types.QueryLeasesRequest) (*type return nil, status.Error(codes.InvalidArgument, "empty request") } - stateVal := types.Lease_State(types.Lease_State_value[req.Filters.State]) - - if req.Filters.State != "" && stateVal == types.LeaseStateInvalid { - return nil, status.Error(codes.InvalidArgument, "invalid state value") - } - if req.Pagination == nil { req.Pagination = &sdkquery.PageRequest{} } else if req.Pagination != nil && req.Pagination.Offset > 0 && req.Filters.State == "" { @@ -301,41 +335,43 @@ func (k Querier) Leases(c 
context.Context, req *types.QueryLeasesRequest) (*type req.Pagination.Limit = sdkquery.DefaultLimit } - reverseSearch := (req.Filters.Owner == "") && (req.Filters.Provider != "") - - states := make([]types.Lease_State, 0, 3) + // setup for case 3 - cross-index search + reverseSearch := false + states := make([]byte, 0, 3) + var searchPrefix []byte // setup for case 3 - cross-index search - if req.Filters.State == "" { - // request has pagination key set, determine store prefix - if len(req.Pagination.Key) > 0 { - if len(req.Pagination.Key) < 3 { - return nil, status.Error(codes.InvalidArgument, "invalid pagination key") - } + // nolint: gocritic + if len(req.Pagination.Key) > 0 { + var key []byte + var unsolicited []byte + var err error + states, searchPrefix, key, unsolicited, err = query.DecodePaginationKey(req.Pagination.Key) + if err != nil { + return nil, status.Error(codes.Internal, err.Error()) + } - if reverseSearch && req.Pagination.Key[2] > keys.LeaseStateActivePrefixID { - return nil, status.Error(codes.InvalidArgument, "invalid pagination key") - } + if len(unsolicited) != 1 { + return nil, status.Error(codes.InvalidArgument, "invalid pagination key") + } + req.Pagination.Key = key - switch req.Pagination.Key[2] { - case keys.LeaseStateActivePrefixID: - states = append(states, types.LeaseActive) - fallthrough - case keys.LeaseStateInsufficientFundsPrefixID: - states = append(states, types.LeaseInsufficientFunds) - fallthrough - case keys.LeaseStateClosedPrefixID: - states = append(states, types.LeaseClosed) - default: - return nil, status.Error(codes.InvalidArgument, "invalid pagination key") - } - } else { - // request does not have pagination set. 
Start from open store - req.Filters.State = types.LeaseActive.String() - states = append(states, types.LeaseActive, types.LeaseInsufficientFunds, types.LeaseClosed) + if unsolicited[0] == 1 { + reverseSearch = true } + } else if req.Filters.State != "" { + reverseSearch = (req.Filters.Owner == "") && (req.Filters.Provider != "") + + stateVal := types.Lease_State(types.Lease_State_value[req.Filters.State]) + + if req.Filters.State != "" && stateVal == types.LeaseStateInvalid { + return nil, status.Error(codes.InvalidArgument, "invalid state value") + } + + states = append(states, byte(stateVal)) } else { - states = append(states, stateVal) + // request does not have pagination set. Start from open store + states = append(states, byte(types.LeaseActive), byte(types.LeaseInsufficientFunds), byte(types.LeaseClosed)) } var leases []types.QueryLeaseResponse @@ -344,20 +380,26 @@ func (k Querier) Leases(c context.Context, req *types.QueryLeasesRequest) (*type total := uint64(0) - for _, state := range states { - var searchPrefix []byte + for idx := range states { + state := types.Lease_State(states[idx]) var err error - req.Filters.State = state.String() - - if reverseSearch { - searchPrefix, err = keys.LeaseReversePrefixFromFilter(req.Filters) - } else { - searchPrefix, err = keys.LeasePrefixFromFilter(req.Filters) + if idx > 0 { + req.Pagination.Key = nil } - if err != nil { - return nil, status.Error(codes.Internal, err.Error()) + if len(req.Pagination.Key) == 0 { + req.Filters.State = state.String() + + if reverseSearch { + searchPrefix, err = keys.LeaseReversePrefixFromFilter(req.Filters) + } else { + searchPrefix, err = keys.LeasePrefixFromFilter(req.Filters) + } + + if err != nil { + return nil, status.Error(codes.Internal, err.Error()) + } } searchedStore := prefix.NewStore(ctx.KVStore(k.skey), searchPrefix) @@ -373,7 +415,7 @@ func (k Querier) Leases(c context.Context, req *types.QueryLeasesRequest) (*type } // filter leases with provided filters - if
req.Filters.Accept(lease, stateVal) { + if req.Filters.Accept(lease, state) { if accumulate { payment, err := k.ekeeper.GetPayment(ctx, dtypes.EscrowAccountForDeployment(lease.ID().DeploymentID()), @@ -403,6 +445,23 @@ func (k Querier) Leases(c context.Context, req *types.QueryLeasesRequest) (*type total += count if req.Pagination.Limit == 0 { + if len(pageRes.NextKey) > 0 { + unsolicited := make([]byte, 1) + unsolicited[0] = 0 + if reverseSearch { + unsolicited[0] = 1 + } + + pageRes.NextKey, err = query.EncodePaginationKey(states[idx:], searchPrefix, pageRes.NextKey, unsolicited) + if err != nil { + pageRes.Total = total + return &types.QueryLeasesResponse{ + Leases: leases, + Pagination: pageRes, + }, status.Error(codes.Internal, err.Error()) + } + } + break } }