diff --git a/protocol/localstatequery/localstatequery.go b/protocol/localstatequery/localstatequery.go index 2888e3e7..f13d97a5 100644 --- a/protocol/localstatequery/localstatequery.go +++ b/protocol/localstatequery/localstatequery.go @@ -155,7 +155,7 @@ type CallbackContext struct { // Callback function types type AcquireFunc func(CallbackContext, AcquireTarget, bool) error -type QueryFunc func(CallbackContext, any) (any, error) +type QueryFunc func(CallbackContext, QueryWrapper) (any, error) type ReleaseFunc func(CallbackContext) error // New returns a new LocalStateQuery object diff --git a/protocol/localstatequery/messages.go b/protocol/localstatequery/messages.go index ab0c1729..7252b189 100644 --- a/protocol/localstatequery/messages.go +++ b/protocol/localstatequery/messages.go @@ -154,7 +154,7 @@ func NewMsgFailure(failure uint8) *MsgFailure { type MsgQuery struct { protocol.MessageBase - Query interface{} + Query QueryWrapper } func NewMsgQuery(query interface{}) *MsgQuery { @@ -162,7 +162,9 @@ func NewMsgQuery(query interface{}) *MsgQuery { MessageBase: protocol.MessageBase{ MessageType: MessageTypeQuery, }, - Query: query, + Query: QueryWrapper{ + Query: query, + }, } return m } diff --git a/protocol/localstatequery/messages_test.go b/protocol/localstatequery/messages_test.go index 16e797d0..8538557b 100644 --- a/protocol/localstatequery/messages_test.go +++ b/protocol/localstatequery/messages_test.go @@ -16,11 +16,12 @@ package localstatequery import ( "encoding/hex" + "reflect" + "testing" + "github.com/blinklabs-io/gouroboros/cbor" "github.com/blinklabs-io/gouroboros/protocol" "github.com/blinklabs-io/gouroboros/protocol/common" - "reflect" - "testing" ) type testDefinition struct { @@ -49,12 +50,12 @@ var tests = []testDefinition{ CborHex: "8203820082028101", Message: NewMsgQuery( // Current era hard-fork query - []interface{}{ - uint64(0), - []interface{}{ - uint64(2), - []interface{}{ - uint64(1), + &BlockQuery{ + Query: &HardForkQuery{ + Query: &HardForkCurrentEraQuery{ + simpleQueryBase{ + Type: QueryTypeHardForkCurrentEra, + }, }, }, }, @@ -105,6 +106,9 @@ func TestDecode(t *testing.T) { } // Set the raw CBOR so the comparison should succeed test.Message.SetCbor(cborData) + if m, ok := msg.(*MsgQuery); ok { + m.Query.SetCbor(nil) + } if !reflect.DeepEqual(msg, test.Message) { t.Fatalf( "CBOR did not decode to expected message object\n got: %#v\n wanted: %#v", diff --git a/protocol/localstatequery/queries.go b/protocol/localstatequery/queries.go index 768f88db..b68a990e 100644 --- a/protocol/localstatequery/queries.go +++ b/protocol/localstatequery/queries.go @@ -61,6 +61,307 @@ const ( QueryTypeShelleyPoolDistr = 21 ) +// simpleQueryBase is a helper type used for various query types to reduce repeat code +type simpleQueryBase struct { + cbor.StructAsArray + Type int +} + +// QueryWrapper is used for decoding a query from CBOR +type QueryWrapper struct { + cbor.DecodeStoreCbor + Query any +} + +func (q *QueryWrapper) UnmarshalCBOR(data []byte) error { + // Store original CBOR + q.SetCbor(data) + // Decode query + tmpQuery, err := decodeQuery( + data, + "", + map[int]any{ + QueryTypeBlock: &BlockQuery{}, + QueryTypeSystemStart: &SystemStartQuery{}, + QueryTypeChainBlockNo: &ChainBlockNoQuery{}, + QueryTypeChainPoint: &ChainPointQuery{}, + }, + ) + if err != nil { + return err + } + q.Query = tmpQuery + return nil +} + +func (q *QueryWrapper) MarshalCBOR() ([]byte, error) { + return cbor.Encode(q.Query) +} + +type BlockQuery struct { + Query any +} + +func (q *BlockQuery) MarshalCBOR() ([]byte, error) { + tmpData := []any{ + QueryTypeBlock, + q.Query, + } + return cbor.Encode(tmpData) +} + +func (q *BlockQuery) UnmarshalCBOR(data []byte) error { + // Unwrap + tmpData := struct { + cbor.StructAsArray + Type int + SubQuery cbor.RawMessage + }{} + if _, err := cbor.Decode(data, &tmpData); err != nil { + return err + } + // Decode query + tmpQuery, err := decodeQuery( + tmpData.SubQuery, + "Block", + map[int]any{ + QueryTypeShelley: &ShelleyQuery{}, + QueryTypeHardFork: &HardForkQuery{}, + }, + ) + if err != nil { + return err + } + q.Query = tmpQuery + return nil +} + +type ShelleyQuery struct { + Era uint + Query any +} + +func (q *ShelleyQuery) MarshalCBOR() ([]byte, error) { + tmpData := []any{ + QueryTypeShelley, + []any{ + q.Era, + q.Query, + }, + } + return cbor.Encode(tmpData) +} + +func (q *ShelleyQuery) UnmarshalCBOR(data []byte) error { + // Unwrap + tmpData := struct { + cbor.StructAsArray + Type int + Inner struct { + cbor.StructAsArray + Era uint + SubQuery cbor.RawMessage + } + }{} + if _, err := cbor.Decode(data, &tmpData); err != nil { + return err + } + // Decode query + tmpQuery, err := decodeQuery( + tmpData.Inner.SubQuery, + "Block", + map[int]any{ + QueryTypeShelleyLedgerTip: &ShelleyLedgerTipQuery{}, + QueryTypeShelleyEpochNo: &ShelleyEpochNoQuery{}, + QueryTypeShelleyNonMyopicMemberRewards: &ShelleyNonMyopicMemberRewardsQuery{}, + QueryTypeShelleyCurrentProtocolParams: &ShelleyCurrentProtocolParamsQuery{}, + QueryTypeShelleyProposedProtocolParamsUpdates: &ShelleyProposedProtocolParamsUpdatesQuery{}, + QueryTypeShelleyStakeDistribution: &ShelleyStakeDistributionQuery{}, + QueryTypeShelleyUtxoByAddress: &ShelleyUtxoByAddressQuery{}, + QueryTypeShelleyUtxoWhole: &ShelleyUtxoWholeQuery{}, + QueryTypeShelleyDebugEpochState: &ShelleyDebugEpochStateQuery{}, + QueryTypeShelleyCbor: &ShelleyCborQuery{}, + QueryTypeShelleyFilteredDelegationAndRewardAccounts: &ShelleyFilteredDelegationAndRewardAccountsQuery{}, + QueryTypeShelleyGenesisConfig: &ShelleyGenesisConfigQuery{}, + QueryTypeShelleyDebugNewEpochState: &ShelleyDebugNewEpochStateQuery{}, + QueryTypeShelleyDebugChainDepState: &ShelleyDebugChainDepStateQuery{}, + QueryTypeShelleyRewardProvenance: &ShelleyRewardProvenanceQuery{}, + QueryTypeShelleyUtxoByTxin: &ShelleyUtxoByTxinQuery{}, + QueryTypeShelleyStakePools: &ShelleyStakePoolsQuery{}, + QueryTypeShelleyStakePoolParams: &ShelleyStakePoolParamsQuery{}, + QueryTypeShelleyRewardInfoPools: &ShelleyRewardInfoPoolsQuery{}, + QueryTypeShelleyPoolState: &ShelleyPoolStateQuery{}, + QueryTypeShelleyStakeSnapshots: &ShelleyStakeSnapshotsQuery{}, + QueryTypeShelleyPoolDistr: &ShelleyPoolDistrQuery{}, + }, + ) + if err != nil { + return err + } + q.Era = tmpData.Inner.Era + q.Query = tmpQuery + return nil +} + +type HardForkQuery struct { + Query any +} + +func (q *HardForkQuery) MarshalCBOR() ([]byte, error) { + tmpData := []any{ + QueryTypeHardFork, + q.Query, + } + return cbor.Encode(tmpData) +} + +func (q *HardForkQuery) UnmarshalCBOR(data []byte) error { + // Unwrap + tmpData := struct { + cbor.StructAsArray + Type int + SubQuery cbor.RawMessage + }{} + if _, err := cbor.Decode(data, &tmpData); err != nil { + return err + } + // Decode query + tmpQuery, err := decodeQuery( + tmpData.SubQuery, + "Hard-fork", + map[int]any{ + QueryTypeHardForkEraHistory: &HardForkEraHistoryQuery{}, + QueryTypeHardForkCurrentEra: &HardForkCurrentEraQuery{}, + }, + ) + if err != nil { + return err + } + q.Query = tmpQuery + return nil +} + +type ShelleyLedgerTipQuery struct { + simpleQueryBase +} + +type ShelleyEpochNoQuery struct { + simpleQueryBase +} + +type ShelleyNonMyopicMemberRewardsQuery struct { + simpleQueryBase +} + +type ShelleyCurrentProtocolParamsQuery struct { + simpleQueryBase +} + +type ShelleyProposedProtocolParamsUpdatesQuery struct { + simpleQueryBase +} + +type ShelleyStakeDistributionQuery struct { + simpleQueryBase +} + +type ShelleyUtxoByAddressQuery struct { + cbor.StructAsArray + Type int + Addrs []ledger.Address +} + +type ShelleyUtxoWholeQuery struct { + simpleQueryBase +} + +type ShelleyDebugEpochStateQuery struct { + simpleQueryBase +} + +type ShelleyCborQuery struct { + simpleQueryBase +} + +type ShelleyFilteredDelegationAndRewardAccountsQuery struct { + simpleQueryBase + // TODO: add params +} + +type ShelleyGenesisConfigQuery struct { + simpleQueryBase +} + +type ShelleyDebugNewEpochStateQuery struct { + simpleQueryBase +} + +type ShelleyDebugChainDepStateQuery struct { + simpleQueryBase +} + +type ShelleyRewardProvenanceQuery struct { + simpleQueryBase +} + +type ShelleyUtxoByTxinQuery struct { + cbor.StructAsArray + Type int + TxIns []ledger.TransactionInput +} + +type ShelleyStakePoolsQuery struct { + simpleQueryBase +} + +type ShelleyStakePoolParamsQuery struct { + simpleQueryBase + // TODO: add params +} + +type ShelleyRewardInfoPoolsQuery struct { + simpleQueryBase +} + +type ShelleyPoolStateQuery struct { + simpleQueryBase +} + +type ShelleyStakeSnapshotsQuery struct { + simpleQueryBase +} + +type ShelleyPoolDistrQuery struct { + simpleQueryBase +} + +func decodeQuery(data []byte, typeDesc string, queryTypes map[int]any) (any, error) { + // Determine query type + queryType, err := cbor.DecodeIdFromList(data) + if err != nil { + return nil, err + } + var tmpQuery any + for typeId, queryObj := range queryTypes { + if queryType == typeId { + tmpQuery = queryObj + break + } + } + if tmpQuery == nil { + errMsg := "unknown query type" + if typeDesc != "" { + errMsg = fmt.Sprintf("unknown %s query type", typeDesc) + } + return nil, fmt.Errorf("%s: %d", errMsg, queryType) + } + // Decode query + if _, err := cbor.Decode(data, tmpQuery); err != nil { + return nil, err + } + return tmpQuery, nil +} + func buildQuery(queryType int, params ...interface{}) []interface{} { ret := []interface{}{queryType} if len(params) > 0 { @@ -104,6 +405,10 @@ func buildShelleyQuery( return ret } +type SystemStartQuery struct { + simpleQueryBase +} + type SystemStartResult struct { // Tells the CBOR decoder to convert to/from a struct and a CBOR array _ struct{} `cbor:",toarray"` @@ -112,6 +417,22 @@ type SystemStartResult struct { Picoseconds uint64 } +type ChainBlockNoQuery struct { + simpleQueryBase +} + +type ChainPointQuery struct { + simpleQueryBase +} + +type HardForkCurrentEraQuery struct { + simpleQueryBase +} + +type HardForkEraHistoryQuery struct { + simpleQueryBase +} + type EraHistoryResult struct { // Tells the CBOR decoder to convert to/from a struct and a CBOR array _ struct{} `cbor:",toarray"`