diff --git a/ledger/babbage/rules.go b/ledger/babbage/rules.go index f013f9f7..0ec29b5c 100644 --- a/ledger/babbage/rules.go +++ b/ledger/babbage/rules.go @@ -109,13 +109,15 @@ func UtxoValidateCollateralContainsNonAda(tx common.Transaction, slot uint64, ls } badOutputs := []common.TransactionOutput{} var totalCollateral uint64 + totalAssets := common.NewMultiAsset[common.MultiAssetTypeOutput](nil) for _, collateralInput := range tx.Collateral() { utxo, err := ls.UtxoById(collateralInput) if err != nil { return err } totalCollateral += utxo.Output.Amount() - if utxo.Output.Assets() == nil { + totalAssets.Add(utxo.Output.Assets()) + if utxo.Output.Assets() == nil || len(utxo.Output.Assets().Policies()) == 0 { continue } badOutputs = append(badOutputs, utxo.Output) @@ -123,6 +125,13 @@ func UtxoValidateCollateralContainsNonAda(tx common.Transaction, slot uint64, ls if len(badOutputs) == 0 { return nil } + // Check if all collateral assets are accounted for in the collateral return + collReturn := tx.CollateralReturn() + if collReturn != nil { + if collReturn.Assets().Compare(&totalAssets) { + return nil + } + } return alonzo.CollateralContainsNonAdaError{ Provided: totalCollateral, } diff --git a/ledger/babbage/rules_test.go b/ledger/babbage/rules_test.go index b864fc08..0ca4fb55 100644 --- a/ledger/babbage/rules_test.go +++ b/ledger/babbage/rules_test.go @@ -1084,7 +1084,9 @@ func TestUtxoValidateCollateralContainsNonAda(t *testing.T) { testInputTxId := "d228b482a1aae768e4a796380f49e021d9c21f70d3c12cb186b188dedfc0ee22" var testCollateralAmount uint64 = 100000 testTx := &babbage.BabbageTransaction{ - Body: babbage.BabbageTransactionBody{}, + Body: babbage.BabbageTransactionBody{ + TxTotalCollateral: testCollateralAmount, + }, WitnessSet: babbage.BabbageTransactionWitnessSet{ AlonzoTransactionWitnessSet: alonzo.AlonzoTransactionWitnessSet{ WsRedeemers: []alonzo.AlonzoRedeemer{ @@ -1095,7 +1097,11 @@ func TestUtxoValidateCollateralContainsNonAda(t *testing.T) { }, } tmpMultiAsset := common.NewMultiAsset[common.MultiAssetTypeOutput]( - map[common.Blake2b224]map[cbor.ByteString]uint64{}, + map[common.Blake2b224]map[cbor.ByteString]uint64{ + common.Blake2b224Hash([]byte("abcd")): map[cbor.ByteString]uint64{ + cbor.NewByteString([]byte("efgh")): 123, + }, + }, ) testLedgerState := test.MockLedgerState{ MockUtxos: []common.Utxo{ @@ -1170,6 +1176,34 @@ func TestUtxoValidateCollateralContainsNonAda(t *testing.T) { } }, ) + // Coin and assets with return + t.Run( + "coin and assets with return", + func(t *testing.T) { + testTx.Body.TxCollateral = []shelley.ShelleyTransactionInput{ + shelley.NewShelleyTransactionInput(testInputTxId, 0), + shelley.NewShelleyTransactionInput(testInputTxId, 1), + } + testTx.Body.TxCollateralReturn = &babbage.BabbageTransactionOutput{ + OutputAmount: mary.MaryTransactionOutputValue{ + Amount: testCollateralAmount, + Assets: &tmpMultiAsset, + }, + } + err := babbage.UtxoValidateCollateralContainsNonAda( + testTx, + testSlot, + testLedgerState, + testProtocolParams, + ) + if err != nil { + t.Errorf( + "UtxoValidateCollateralContainsNonAda should succeed when collateral with only coin is provided\n got error: %v", + err, + ) + } + }, + ) } func TestUtxoValidateNoCollateralInputs(t *testing.T) { diff --git a/ledger/common/common.go b/ledger/common/common.go index 9728ee1f..69112cde 100644 --- a/ledger/common/common.go +++ b/ledger/common/common.go @@ -141,6 +141,9 @@ type MultiAsset[T MultiAssetTypeOutput | MultiAssetTypeMint] struct { func NewMultiAsset[T MultiAssetTypeOutput | MultiAssetTypeMint]( data map[Blake2b224]map[cbor.ByteString]T, ) MultiAsset[T] { + if data == nil { + data = make(map[Blake2b224]map[cbor.ByteString]T) + } return MultiAsset[T]{data: data} } @@ -210,6 +213,41 @@ func (m *MultiAsset[T]) Asset(policyId Blake2b224, assetName []byte) T { return policy[cbor.NewByteString(assetName)] } +func (m *MultiAsset[T]) Add(assets *MultiAsset[T]) { + if assets == nil { + return + } + for policy, assets := range assets.data { + for asset, amount := range assets { + newAmount := m.Asset(policy, asset.Bytes()) + amount + if _, ok := m.data[policy]; !ok { + m.data[policy] = make(map[cbor.ByteString]T) + } + m.data[policy][asset] = newAmount + } + } +} + +func (m *MultiAsset[T]) Compare(assets *MultiAsset[T]) bool { + if assets == nil { + return false + } + if len(assets.data) != len(m.data) { + return false + } + for policy, assets := range assets.data { + if len(assets) != len(m.data[policy]) { + return false + } + for asset, amount := range assets { + if amount != m.Asset(policy, asset.Bytes()) { + return false + } + } + } + return true +} + type AssetFingerprint struct { policyId []byte assetName []byte diff --git a/ledger/conway/rules.go b/ledger/conway/rules.go index 5fd70413..97d5c250 100644 --- a/ledger/conway/rules.go +++ b/ledger/conway/rules.go @@ -129,13 +129,15 @@ func UtxoValidateCollateralContainsNonAda(tx common.Transaction, slot uint64, ls } badOutputs := []common.TransactionOutput{} var totalCollateral uint64 + totalAssets := common.NewMultiAsset[common.MultiAssetTypeOutput](nil) for _, collateralInput := range tx.Collateral() { utxo, err := ls.UtxoById(collateralInput) if err != nil { return err } totalCollateral += utxo.Output.Amount() - if utxo.Output.Assets() == nil { + totalAssets.Add(utxo.Output.Assets()) + if utxo.Output.Assets() == nil || len(utxo.Output.Assets().Policies()) == 0 { continue } badOutputs = append(badOutputs, utxo.Output) @@ -143,6 +145,13 @@ func UtxoValidateCollateralContainsNonAda(tx common.Transaction, slot uint64, ls if len(badOutputs) == 0 { return nil } + // Check if all collateral assets are accounted for in the collateral return + collReturn := tx.CollateralReturn() + if collReturn != nil { + if collReturn.Assets().Compare(&totalAssets) { + return nil + } + } return alonzo.CollateralContainsNonAdaError{ Provided: totalCollateral, } diff --git a/ledger/conway/rules_test.go b/ledger/conway/rules_test.go index 5281f097..ea9d48de 100644 --- a/ledger/conway/rules_test.go +++ b/ledger/conway/rules_test.go @@ -1095,7 +1095,11 @@ func TestUtxoValidateCollateralContainsNonAda(t *testing.T) { testInputTxId := "d228b482a1aae768e4a796380f49e021d9c21f70d3c12cb186b188dedfc0ee22" var testCollateralAmount uint64 = 100000 testTx := &conway.ConwayTransaction{ - Body: conway.ConwayTransactionBody{}, + Body: conway.ConwayTransactionBody{ + BabbageTransactionBody: babbage.BabbageTransactionBody{ + TxTotalCollateral: testCollateralAmount, + }, + }, WitnessSet: conway.ConwayTransactionWitnessSet{ WsRedeemers: conway.ConwayRedeemers{ Redeemers: map[conway.ConwayRedeemerKey]conway.ConwayRedeemerValue{ @@ -1106,7 +1110,11 @@ func TestUtxoValidateCollateralContainsNonAda(t *testing.T) { }, } tmpMultiAsset := common.NewMultiAsset[common.MultiAssetTypeOutput]( - map[common.Blake2b224]map[cbor.ByteString]uint64{}, + map[common.Blake2b224]map[cbor.ByteString]uint64{ + common.Blake2b224Hash([]byte("abcd")): map[cbor.ByteString]uint64{ + cbor.NewByteString([]byte("efgh")): 123, + }, + }, ) testLedgerState := test.MockLedgerState{ MockUtxos: []common.Utxo{ @@ -1181,6 +1189,34 @@ func TestUtxoValidateCollateralContainsNonAda(t *testing.T) { } }, ) + // Coin and assets with return + t.Run( + "coin and assets with return", + func(t *testing.T) { + testTx.Body.TxCollateral = []shelley.ShelleyTransactionInput{ + shelley.NewShelleyTransactionInput(testInputTxId, 0), + shelley.NewShelleyTransactionInput(testInputTxId, 1), + } + testTx.Body.TxCollateralReturn = &babbage.BabbageTransactionOutput{ + OutputAmount: mary.MaryTransactionOutputValue{ + Amount: testCollateralAmount, + Assets: &tmpMultiAsset, + }, + } + err := conway.UtxoValidateCollateralContainsNonAda( + testTx, + testSlot, + testLedgerState, + testProtocolParams, + ) + if err != nil { + t.Errorf( + "UtxoValidateCollateralContainsNonAda should succeed when collateral with only coin is provided\n got error: %v", + err, + ) + } + }, + ) } func TestUtxoValidateNoCollateralInputs(t *testing.T) {