Skip to content

Commit 2ed1acd

Browse files
authored
Merge pull request #899 from multiversx/back-transfers-improvements
back transfers improvements
2 parents c1734c7 + 6543e4c commit 2ed1acd

File tree

5 files changed

+201
-10
lines changed

5 files changed

+201
-10
lines changed

vmhost/contexts/managedType.go

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@ type managedTypesContext struct {
6464
type backTransfers struct {
6565
ESDTTransfers []*vmcommon.ESDTTransfer
6666
CallValue *big.Int
67+
LastIndex uint32
6768
}
6869

6970
type managedTypesState struct {
@@ -833,13 +834,20 @@ func (context *managedTypesContext) getKeyValueFromManagedMap(mMapHandle int32,
833834
}
834835

835836
// AddBackTransfers add transfers to back transfers structure
836-
func (context *managedTypesContext) AddBackTransfers(transfers []*vmcommon.ESDTTransfer) {
837-
context.managedTypesValues.backTransfers.ESDTTransfers = append(context.managedTypesValues.backTransfers.ESDTTransfers, transfers...)
838-
}
837+
func (context *managedTypesContext) AddBackTransfers(value *big.Int, transfers []*vmcommon.ESDTTransfer, index uint32) {
838+
backTrs := &context.managedTypesValues.backTransfers
839+
if context.host.EnableEpochsHandler().IsFlagEnabled(vmhost.FixBackTransferOPCODE) && backTrs.LastIndex >= index {
840+
return
841+
}
839842

840-
// AddValueOnlyBackTransfer add to back transfer value
841-
func (context *managedTypesContext) AddValueOnlyBackTransfer(value *big.Int) {
842-
context.managedTypesValues.backTransfers.CallValue.Add(context.managedTypesValues.backTransfers.CallValue, value)
843+
if backTrs.LastIndex < index {
844+
backTrs.LastIndex = index
845+
}
846+
847+
backTrs.CallValue.Add(backTrs.CallValue, value)
848+
if len(transfers) > 0 {
849+
backTrs.ESDTTransfers = append(backTrs.ESDTTransfers, transfers...)
850+
}
843851
}
844852

845853
// GetBackTransfers returns all ESDT transfers and accumulated value as well, will clean accumulated values
@@ -848,6 +856,7 @@ func (context *managedTypesContext) GetBackTransfers() ([]*vmcommon.ESDTTransfer
848856
context.managedTypesValues.backTransfers = backTransfers{
849857
ESDTTransfers: make([]*vmcommon.ESDTTransfer, 0),
850858
CallValue: big.NewInt(0),
859+
LastIndex: clonedTransfers.LastIndex,
851860
}
852861

853862
return clonedTransfers.ESDTTransfers, clonedTransfers.CallValue
@@ -857,6 +866,7 @@ func cloneBackTransfers(currentBackTransfers backTransfers) backTransfers {
857866
newBackTransfers := backTransfers{
858867
ESDTTransfers: make([]*vmcommon.ESDTTransfer, len(currentBackTransfers.ESDTTransfers)),
859868
CallValue: big.NewInt(0).Set(currentBackTransfers.CallValue),
869+
LastIndex: currentBackTransfers.LastIndex,
860870
}
861871

862872
for index, transfer := range currentBackTransfers.ESDTTransfers {

vmhost/flags.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,4 +14,7 @@ const (
1414

1515
// MaskInternalDependenciesErrorsFlag defines the flag that activates masking of internal dependencies errors
1616
MaskInternalDependenciesErrorsFlag core.EnableEpochFlag = "MaskInternalDependenciesErrorsFlag"
17+
18+
// FixBackTransferOPCODE defines the flag that activates the fix for get back transfer opcode
19+
FixBackTransferOPCODE core.EnableEpochFlag = "FixBackTransferOPCODEFlag"
1720
)

vmhost/hostCore/execution.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -359,7 +359,7 @@ func (host *vmHost) addNewBackTransfersFromVMOutput(vmOutput *vmcommon.VMOutput,
359359

360360
if transfer.Value.Cmp(vmhost.Zero) > 0 {
361361
if len(transfer.Data) == 0 {
362-
host.managedTypesContext.AddValueOnlyBackTransfer(transfer.Value)
362+
host.managedTypesContext.AddBackTransfers(transfer.Value, nil, transfer.Index)
363363
}
364364
continue
365365
}
@@ -369,7 +369,7 @@ func (host *vmHost) addNewBackTransfersFromVMOutput(vmOutput *vmcommon.VMOutput,
369369
continue
370370
}
371371

372-
host.managedTypesContext.AddBackTransfers(esdtTransfers.ESDTTransfers)
372+
host.managedTypesContext.AddBackTransfers(vmhost.Zero, esdtTransfers.ESDTTransfers, transfer.Index)
373373
}
374374
}
375375

vmhost/hosttest/managedei_test.go

Lines changed: 179 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1689,6 +1689,185 @@ func Test_Direct_ManagedGetBackTransfers(t *testing.T) {
16891689
assert.Nil(t, err)
16901690
}
16911691

1692+
func Test_MultipleCalls_ManagedGetBackTransfers(t *testing.T) {
1693+
testConfig := makeTestConfig()
1694+
egldBalance := big.NewInt(10)
1695+
egldTransfer := big.NewInt(1)
1696+
initialESDTTokenBalance := uint64(100)
1697+
testConfig.ESDTTokensToTransfer = 5
1698+
callsNumber := 2
1699+
1700+
_, err := test.BuildMockInstanceCallTest(t).
1701+
WithContracts(
1702+
test.CreateMockContract(test.ParentAddress).
1703+
WithBalance(testConfig.ParentBalance).
1704+
WithConfig(testConfig).
1705+
WithMethods(func(parentInstance *mock.InstanceMock, config interface{}) {
1706+
parentInstance.AddMockMethod("callChild", func() *mock.InstanceMock {
1707+
host := parentInstance.Host
1708+
1709+
for i := 0; i < callsNumber; i++ {
1710+
input := test.DefaultTestContractCallInput()
1711+
input.GasProvided = testConfig.GasProvidedToChild
1712+
input.CallerAddr = test.ParentAddress
1713+
input.RecipientAddr = test.ChildAddress
1714+
input.Function = "childFunction"
1715+
returnValue := contracts.ExecuteOnDestContextInMockContracts(host, input)
1716+
assert.Equal(t, int32(0), returnValue)
1717+
}
1718+
1719+
managedTypes := host.ManagedTypes()
1720+
esdtTransfers, egld := managedTypes.GetBackTransfers()
1721+
assert.Equal(t, callsNumber, len(esdtTransfers))
1722+
for i := 0; i < callsNumber; i++ {
1723+
assert.Equal(t, test.ESDTTestTokenName, esdtTransfers[i].ESDTTokenName)
1724+
assert.Equal(t, big.NewInt(0).SetUint64(testConfig.ESDTTokensToTransfer), esdtTransfers[i].ESDTValue)
1725+
}
1726+
assert.Equal(t, big.NewInt(egldTransfer.Int64()*int64(callsNumber)), egld)
1727+
return parentInstance
1728+
})
1729+
}),
1730+
test.CreateMockContract(test.ChildAddress).
1731+
WithBalance(testConfig.ChildBalance).
1732+
WithConfig(testConfig).
1733+
WithMethods(func(parentInstance *mock.InstanceMock, config interface{}) {
1734+
parentInstance.AddMockMethod("childFunction", func() *mock.InstanceMock {
1735+
host := parentInstance.Host
1736+
1737+
valueBytes := egldTransfer.Bytes()
1738+
err := host.Output().Transfer(
1739+
test.ParentAddress,
1740+
test.ChildAddress, 0, 0, big.NewInt(0).SetBytes(valueBytes), nil, []byte{}, vm.DirectCall)
1741+
assert.Nil(t, err)
1742+
1743+
transfer := &vmcommon.ESDTTransfer{
1744+
ESDTValue: big.NewInt(int64(testConfig.ESDTTokensToTransfer)),
1745+
ESDTTokenName: test.ESDTTestTokenName,
1746+
ESDTTokenType: 0,
1747+
ESDTTokenNonce: 0,
1748+
}
1749+
1750+
ret := vmhooks.TransferESDTNFTExecuteWithTypedArgs(
1751+
host,
1752+
test.ParentAddress,
1753+
[]*vmcommon.ESDTTransfer{transfer},
1754+
int64(testConfig.GasProvidedToChild),
1755+
nil,
1756+
nil)
1757+
assert.Equal(t, ret, int32(0))
1758+
1759+
return parentInstance
1760+
})
1761+
}),
1762+
).
1763+
WithSetup(func(host vmhost.VMHost, world *worldmock.MockWorld) {
1764+
childAccount := world.AcctMap.GetAccount(test.ChildAddress)
1765+
childAccount.SetBalance(egldBalance.Int64())
1766+
_ = childAccount.SetTokenBalanceUint64(test.ESDTTestTokenName, 0, initialESDTTokenBalance)
1767+
createMockBuiltinFunctions(t, host, world)
1768+
setZeroCodeCosts(host)
1769+
}).
1770+
WithInput(test.CreateTestContractCallInputBuilder().
1771+
WithRecipientAddr(test.ParentAddress).
1772+
WithGasProvided(testConfig.GasProvided).
1773+
WithFunction("callChild").
1774+
Build()).
1775+
AndAssertResults(func(world *worldmock.MockWorld, verify *test.VMOutputVerifier) {
1776+
verify.
1777+
Ok()
1778+
})
1779+
assert.Nil(t, err)
1780+
}
1781+
1782+
func Test_MultipleCalls_MultipleReads_ManagedGetBackTransfers(t *testing.T) {
1783+
testConfig := makeTestConfig()
1784+
egldBalance := big.NewInt(10)
1785+
egldTransfer := big.NewInt(1)
1786+
initialESDTTokenBalance := uint64(100)
1787+
testConfig.ESDTTokensToTransfer = 5
1788+
callsNumber := 2
1789+
1790+
_, err := test.BuildMockInstanceCallTest(t).
1791+
WithContracts(
1792+
test.CreateMockContract(test.ParentAddress).
1793+
WithBalance(testConfig.ParentBalance).
1794+
WithConfig(testConfig).
1795+
WithMethods(func(parentInstance *mock.InstanceMock, config interface{}) {
1796+
parentInstance.AddMockMethod("callChild", func() *mock.InstanceMock {
1797+
host := parentInstance.Host
1798+
1799+
for i := 0; i < callsNumber; i++ {
1800+
input := test.DefaultTestContractCallInput()
1801+
input.GasProvided = testConfig.GasProvidedToChild
1802+
input.CallerAddr = test.ParentAddress
1803+
input.RecipientAddr = test.ChildAddress
1804+
input.Function = "childFunction"
1805+
returnValue := contracts.ExecuteOnDestContextInMockContracts(host, input)
1806+
assert.Equal(t, int32(0), returnValue)
1807+
1808+
managedTypes := host.ManagedTypes()
1809+
esdtTransfers, egld := managedTypes.GetBackTransfers()
1810+
assert.Equal(t, 1, len(esdtTransfers))
1811+
assert.Equal(t, test.ESDTTestTokenName, esdtTransfers[0].ESDTTokenName)
1812+
assert.Equal(t, big.NewInt(0).SetUint64(testConfig.ESDTTokensToTransfer), esdtTransfers[0].ESDTValue)
1813+
assert.Equal(t, egldTransfer, egld)
1814+
}
1815+
1816+
return parentInstance
1817+
})
1818+
}),
1819+
test.CreateMockContract(test.ChildAddress).
1820+
WithBalance(testConfig.ChildBalance).
1821+
WithConfig(testConfig).
1822+
WithMethods(func(parentInstance *mock.InstanceMock, config interface{}) {
1823+
parentInstance.AddMockMethod("childFunction", func() *mock.InstanceMock {
1824+
host := parentInstance.Host
1825+
1826+
valueBytes := egldTransfer.Bytes()
1827+
err := host.Output().Transfer(
1828+
test.ParentAddress,
1829+
test.ChildAddress, 0, 0, big.NewInt(0).SetBytes(valueBytes), nil, []byte{}, vm.DirectCall)
1830+
assert.Nil(t, err)
1831+
1832+
transfer := &vmcommon.ESDTTransfer{
1833+
ESDTValue: big.NewInt(int64(testConfig.ESDTTokensToTransfer)),
1834+
ESDTTokenName: test.ESDTTestTokenName,
1835+
ESDTTokenType: 0,
1836+
ESDTTokenNonce: 0,
1837+
}
1838+
1839+
ret := vmhooks.TransferESDTNFTExecuteWithTypedArgs(
1840+
host,
1841+
test.ParentAddress,
1842+
[]*vmcommon.ESDTTransfer{transfer},
1843+
int64(testConfig.GasProvidedToChild),
1844+
nil,
1845+
nil)
1846+
assert.Equal(t, ret, int32(0))
1847+
1848+
return parentInstance
1849+
})
1850+
}),
1851+
).
1852+
WithSetup(func(host vmhost.VMHost, world *worldmock.MockWorld) {
1853+
childAccount := world.AcctMap.GetAccount(test.ChildAddress)
1854+
childAccount.SetBalance(egldBalance.Int64())
1855+
_ = childAccount.SetTokenBalanceUint64(test.ESDTTestTokenName, 0, initialESDTTokenBalance)
1856+
createMockBuiltinFunctions(t, host, world)
1857+
setZeroCodeCosts(host)
1858+
}).
1859+
WithInput(test.CreateTestContractCallInputBuilder().
1860+
WithRecipientAddr(test.ParentAddress).
1861+
WithGasProvided(testConfig.GasProvided).
1862+
WithFunction("callChild").
1863+
Build()).
1864+
AndAssertResults(func(world *worldmock.MockWorld, verify *test.VMOutputVerifier) {
1865+
verify.
1866+
Ok()
1867+
})
1868+
assert.Nil(t, err)
1869+
}
1870+
16921871
func Test_Async_ManagedGetBackTransfers(t *testing.T) {
16931872
testConfig := makeTestConfig()
16941873
initialESDTTokenBalance := uint64(100)

vmhost/interface.go

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -221,8 +221,7 @@ type ManagedTypesContext interface {
221221
ManagedMapRemove(mMapHandle int32, keyHandle int32, outValueHandle int32) error
222222
ManagedMapContains(mMapHandle int32, keyHandle int32) (bool, error)
223223
GetBackTransfers() ([]*vmcommon.ESDTTransfer, *big.Int)
224-
AddValueOnlyBackTransfer(value *big.Int)
225-
AddBackTransfers(transfers []*vmcommon.ESDTTransfer)
224+
AddBackTransfers(value *big.Int, transfers []*vmcommon.ESDTTransfer, index uint32)
226225
PopBackTransferIfAsyncCallBack(vmInput *vmcommon.ContractCallInput)
227226
}
228227

0 commit comments

Comments
 (0)