diff --git a/ledger/babbage/rules.go b/ledger/babbage/rules.go index c6e66661..3c3f7f0d 100644 --- a/ledger/babbage/rules.go +++ b/ledger/babbage/rules.go @@ -154,10 +154,8 @@ func UtxoValidateCollateralContainsNonAda( collReturn := tx.CollateralReturn() if collReturn != nil { collReturnAssets := collReturn.Assets() - if collReturnAssets != nil { - if collReturnAssets.Compare(&totalAssets) { - return nil - } + if (&totalAssets).Compare(collReturnAssets) { + return nil } } return alonzo.CollateralContainsNonAdaError{ diff --git a/ledger/babbage/rules_test.go b/ledger/babbage/rules_test.go index c1341e86..a2a01be8 100644 --- a/ledger/babbage/rules_test.go +++ b/ledger/babbage/rules_test.go @@ -1118,6 +1118,13 @@ func TestUtxoValidateCollateralContainsNonAda(t *testing.T) { }, }, ) + tmpZeroMultiAsset := common.NewMultiAsset[common.MultiAssetTypeOutput]( + map[common.Blake2b224]map[cbor.ByteString]uint64{ + common.Blake2b224Hash([]byte("abcd")): map[cbor.ByteString]uint64{ + cbor.NewByteString([]byte("efgh")): 0, + }, + }, + ) testLedgerState := test.MockLedgerState{ MockUtxos: []common.Utxo{ { @@ -1135,6 +1142,15 @@ func TestUtxoValidateCollateralContainsNonAda(t *testing.T) { }, }, }, + { + Id: shelley.NewShelleyTransactionInput(testInputTxId, 2), + Output: babbage.BabbageTransactionOutput{ + OutputAmount: mary.MaryTransactionOutputValue{ + Amount: testCollateralAmount, + Assets: &tmpZeroMultiAsset, + }, + }, + }, }, } testSlot := uint64(0) @@ -1219,6 +1235,32 @@ func TestUtxoValidateCollateralContainsNonAda(t *testing.T) { } }, ) + // Coin and zero assets with return + t.Run( + "coin and zero assets with return", + func(t *testing.T) { + testTx.Body.TxCollateral = []shelley.ShelleyTransactionInput{ + shelley.NewShelleyTransactionInput(testInputTxId, 2), + } + testTx.Body.TxCollateralReturn = &babbage.BabbageTransactionOutput{ + OutputAmount: mary.MaryTransactionOutputValue{ + Amount: testCollateralAmount, + }, + } + err := babbage.UtxoValidateCollateralContainsNonAda( + testTx, + testSlot, + testLedgerState, + testProtocolParams, + ) + if err != nil { + t.Errorf( + "UtxoValidateCollateralContainsNonAda should succeed when collateral with only coin is provided\n got error: %v", + err, + ) + } + }, + ) } func TestUtxoValidateNoCollateralInputs(t *testing.T) { diff --git a/ledger/common/common.go b/ledger/common/common.go index 69112cde..e40ca5c0 100644 --- a/ledger/common/common.go +++ b/ledger/common/common.go @@ -229,17 +229,20 @@ func (m *MultiAsset[T]) Add(assets *MultiAsset[T]) { } func (m *MultiAsset[T]) Compare(assets *MultiAsset[T]) bool { - if assets == nil { - return false - } - if len(assets.data) != len(m.data) { + // Normalize data for easier comparison + tmpData := m.normalize() + otherData := assets.normalize() + // Compare policy counts + if len(otherData) != len(tmpData) { return false } - for policy, assets := range assets.data { - if len(assets) != len(m.data[policy]) { + for policy, assets := range otherData { + // Compare asset counts for policy + if len(assets) != len(tmpData[policy]) { return false } for asset, amount := range assets { + // Compare quantity of specific asset if amount != m.Asset(policy, asset.Bytes()) { return false } @@ -248,6 +251,24 @@ func (m *MultiAsset[T]) Compare(assets *MultiAsset[T]) bool { return true } +func (m *MultiAsset[T]) normalize() map[Blake2b224]map[cbor.ByteString]T { + ret := map[Blake2b224]map[cbor.ByteString]T{} + if m == nil || m.data == nil { + return ret + } + for policy, assets := range m.data { + for asset, amount := range assets { + if amount != 0 { + if _, ok := ret[policy]; !ok { + ret[policy] = make(map[cbor.ByteString]T) + } + ret[policy][asset] = amount + } + } + } + return ret +} + type AssetFingerprint struct { policyId []byte assetName []byte diff --git a/ledger/common/common_test.go b/ledger/common/common_test.go index a4eba4a8..5b7a03a4 100644 --- a/ledger/common/common_test.go +++ b/ledger/common/common_test.go @@ -126,6 +126,84 @@ func TestMultiAssetJson(t *testing.T) { } } +func TestMultiAssetCompare(t *testing.T) { + testDefs := []struct { + asset1 *MultiAsset[MultiAssetTypeOutput] + asset2 *MultiAsset[MultiAssetTypeOutput] + expectedResult bool + }{ + { + asset1: &MultiAsset[MultiAssetTypeOutput]{ + data: map[Blake2b224]map[cbor.ByteString]MultiAssetTypeOutput{ + NewBlake2b224([]byte("abcd")): { + cbor.NewByteString([]byte("cdef")): 123, + }, + }, + }, + asset2: &MultiAsset[MultiAssetTypeOutput]{ + data: map[Blake2b224]map[cbor.ByteString]MultiAssetTypeOutput{ + NewBlake2b224([]byte("abcd")): { + cbor.NewByteString([]byte("cdef")): 123, + }, + }, + }, + expectedResult: true, + }, + { + asset1: &MultiAsset[MultiAssetTypeOutput]{ + data: map[Blake2b224]map[cbor.ByteString]MultiAssetTypeOutput{ + NewBlake2b224([]byte("abcd")): { + cbor.NewByteString([]byte("cdef")): 123, + }, + }, + }, + asset2: &MultiAsset[MultiAssetTypeOutput]{ + data: map[Blake2b224]map[cbor.ByteString]MultiAssetTypeOutput{ + NewBlake2b224([]byte("abcd")): { + cbor.NewByteString([]byte("cdef")): 124, + }, + }, + }, + expectedResult: false, + }, + { + asset1: &MultiAsset[MultiAssetTypeOutput]{ + data: map[Blake2b224]map[cbor.ByteString]MultiAssetTypeOutput{ + NewBlake2b224([]byte("abcd")): { + cbor.NewByteString([]byte("cdef")): 0, + }, + }, + }, + asset2: nil, + expectedResult: true, + }, + { + asset1: &MultiAsset[MultiAssetTypeOutput]{ + data: map[Blake2b224]map[cbor.ByteString]MultiAssetTypeOutput{ + NewBlake2b224([]byte("abcd")): { + cbor.NewByteString([]byte("cdef")): 123, + }, + }, + }, + asset2: &MultiAsset[MultiAssetTypeOutput]{ + data: map[Blake2b224]map[cbor.ByteString]MultiAssetTypeOutput{ + NewBlake2b224([]byte("abcd")): { + cbor.NewByteString([]byte("cdef")): 123, + cbor.NewByteString([]byte("efgh")): 123, + }, + }, + }, + expectedResult: false, + }, + } + for _, testDef := range testDefs { + tmpResult := testDef.asset1.Compare(testDef.asset2) + if tmpResult != testDef.expectedResult { + t.Errorf("did not get expected result: got %v, wanted %v", tmpResult, testDef.expectedResult) + } + } +} + // Test the MarshalJSON method for Blake2b224 to ensure it properly converts to JSON. func TestBlake2b224_MarshalJSON(t *testing.T) { // Example data to represent Blake2b224 hash diff --git a/ledger/conway/rules.go b/ledger/conway/rules.go index 31fd6104..de12c639 100644 --- a/ledger/conway/rules.go +++ b/ledger/conway/rules.go @@ -179,10 +179,8 @@ func UtxoValidateCollateralContainsNonAda( collReturn := tx.CollateralReturn() if collReturn != nil { collReturnAssets := collReturn.Assets() - if collReturnAssets != nil { - if collReturnAssets.Compare(&totalAssets) { - return nil - } + if (&totalAssets).Compare(collReturnAssets) { + return nil } } return alonzo.CollateralContainsNonAdaError{ diff --git a/ledger/conway/rules_test.go b/ledger/conway/rules_test.go index ae3150e6..ee1aad61 100644 --- a/ledger/conway/rules_test.go +++ b/ledger/conway/rules_test.go @@ -1128,6 +1128,13 @@ func TestUtxoValidateCollateralContainsNonAda(t *testing.T) { }, }, ) + tmpZeroMultiAsset := common.NewMultiAsset[common.MultiAssetTypeOutput]( + map[common.Blake2b224]map[cbor.ByteString]uint64{ + common.Blake2b224Hash([]byte("abcd")): map[cbor.ByteString]uint64{ + cbor.NewByteString([]byte("efgh")): 0, + }, + }, + ) testLedgerState := test.MockLedgerState{ MockUtxos: []common.Utxo{ { @@ -1145,6 +1152,15 @@ func TestUtxoValidateCollateralContainsNonAda(t *testing.T) { }, }, }, + { + Id: shelley.NewShelleyTransactionInput(testInputTxId, 2), + Output: babbage.BabbageTransactionOutput{ + OutputAmount: mary.MaryTransactionOutputValue{ + Amount: testCollateralAmount, + Assets: &tmpZeroMultiAsset, + }, + }, + }, }, } testSlot := uint64(0) @@ -1229,6 +1245,32 @@ func TestUtxoValidateCollateralContainsNonAda(t *testing.T) { } }, ) + // Coin and zero assets with return + t.Run( + "coin and zero assets with return", + func(t *testing.T) { + testTx.Body.TxCollateral = []shelley.ShelleyTransactionInput{ + shelley.NewShelleyTransactionInput(testInputTxId, 2), + } + testTx.Body.TxCollateralReturn = &babbage.BabbageTransactionOutput{ + OutputAmount: mary.MaryTransactionOutputValue{ + Amount: testCollateralAmount, + }, + } + err := conway.UtxoValidateCollateralContainsNonAda( + testTx, + testSlot, + testLedgerState, + testProtocolParams, + ) + if err != nil { + t.Errorf( + "UtxoValidateCollateralContainsNonAda should succeed when collateral with only coin is provided\n got error: %v", + err, + ) + } + }, + ) } func TestUtxoValidateNoCollateralInputs(t *testing.T) {