|
| 1 | +/* |
| 2 | + * The Certora Prover |
| 3 | + * Copyright (C) 2025 Certora Ltd. |
| 4 | + * |
| 5 | + * This program is free software: you can redistribute it and/or modify |
| 6 | + * it under the terms of the GNU General Public License as published by |
| 7 | + * the Free Software Foundation, version 3 of the License. |
| 8 | + * |
| 9 | + * This program is distributed in the hope that it will be useful, |
| 10 | + * but WITHOUT ANY WARRANTY; without even the implied warranty of |
| 11 | + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the |
| 12 | + * GNU General Public License for more details. |
| 13 | + * |
| 14 | + * You should have received a copy of the GNU General Public License |
| 15 | + * along with this program. If not, see <https://www.gnu.org/licenses/>. |
| 16 | + */ |
| 17 | + |
| 18 | +package sbf.cfg |
| 19 | + |
| 20 | +import log.Logger |
| 21 | +import log.LoggerTypes |
| 22 | +import sbf.SolanaConfig |
| 23 | +import sbf.disassembler.SbfRegister |
| 24 | +import sbf.domains.FiniteInterval |
| 25 | +import sbf.domains.SetOfFiniteIntervals |
| 26 | +import kotlin.math.absoluteValue |
| 27 | + |
/** Logger for the SBF memcpy-promotion pass. */
private val logger = Logger(LoggerTypes.SBF_MEMCPY_PROMOTION)

/** Forwards [msg] to [logger] at debug level. */
private fun dbg(msg: () -> Any) {
    logger.debug(msg)
}
| 30 | + |
| 31 | +/** |
| 32 | + * Represent a memcpy pattern, i.e., sequence of load/store pairs |
| 33 | + **/ |
/**
 * Represents a memcpy pattern: a sequence of load/store pairs within a single
 * basic block that may be promoted to a single memcpy call.
 **/
class MemcpyPattern {
    /** Class invariants:
     * 0. `loads.size == stores.size == loadLocInsts.size == storeLocInsts.size`
     * 1. `forall 0 <= i < size-1 :: distance(loads[i], loads[i+1]) == distance(stores[i], stores[i+1])`
     * 2. `forall i,j, i!=j :: loads[i].region == loads[j].region && stores[i].region == stores[j].region`
     * 3. `forall i,j, i!=j :: loads[i].reg == loads[j].reg && stores[i].reg == stores[j].reg`
     **/
    private val loads = mutableListOf<MemAccess>()
    private val stores = mutableListOf<MemAccess>()
    private val loadLocInsts = mutableListOf<LocatedSbfInstruction>()
    private val storeLocInsts = mutableListOf<LocatedSbfInstruction>()

    fun getStores(): List<LocatedSbfInstruction> = storeLocInsts
    fun getLoads(): List<LocatedSbfInstruction> = loadLocInsts

    /** Appends one load/store pair to all four parallel lists, keeping them in lockstep. */
    private fun append(load: MemAccess, loadLocInst: LocatedSbfInstruction,
                       store: MemAccess, storeLocInst: LocatedSbfInstruction) {
        loads.add(load)
        stores.add(store)
        loadLocInsts.add(loadLocInst)
        storeLocInsts.add(storeLocInst)
    }

    /** Removes the most recently appended load/store pair from all four parallel lists. */
    private fun removeLastPair() {
        loads.removeLast()
        stores.removeLast()
        loadLocInsts.removeLast()
        storeLocInsts.removeLast()
    }

    /**
     * Return false if it cannot maintain all class invariants (See above).
     * Otherwise, it will add the [load]/[store] pair as part of a memcpy.
     *
     * Note that invariant #1 is checked here with *absolute* distances, so a pair
     * whose loads increase while its stores decrease is still accepted; the
     * same-sign ordering requirement is enforced later in [canBePromoted].
     **/
    fun add(load: MemAccess,
            loadLocInst: LocatedSbfInstruction,
            store: MemAccess,
            storeLocInst: LocatedSbfInstruction
    ): Boolean {
        check(loadLocInst.label == storeLocInst.label)
        {"can only promote pairs of load-store within the same block"}

        if (loads.isEmpty()) {
            append(load, loadLocInst, store, storeLocInst)
            return true
        }
        val lastLoad = loads.last()
        val lastStore = stores.last()
        // class invariant #3: all loads use the same register, all stores use the same register
        if (lastLoad.reg != load.reg || lastStore.reg != store.reg) {
            return false
        }
        // class invariant #2: all loads share one region, all stores share one region
        if (lastLoad.region != load.region || lastStore.region != store.region) {
            return false
        }
        // class invariant #1: consecutive loads are as far apart as consecutive stores
        val loadDiff = (load.offset - lastLoad.offset).absoluteValue
        val storeDiff = (store.offset - lastStore.offset).absoluteValue
        if (loadDiff != storeDiff) {
            return false
        }
        append(load, loadLocInst, store, storeLocInst)
        return true
    }

    /** Arguments of a promoted memcpy: source/destination base registers, start offsets, and byte count. */
    data class MemcpyArgs(
        val srcReg: SbfRegister,
        val srcStart: Long,
        val dstReg: SbfRegister,
        val dstStart: Long,
        val size: ULong,
        val metadata: MetaData
    )

    /**
     * Return non-null if
     * (1) source and destination do not overlap and
     * (2) the sequence of loads and stores accesses memory in the same ordering (decreasing or increasing) and
     * (3) the sequences form a consecutive range of bytes without holes in between
     * and the resulting range is at least [minSizeToBePromoted] bytes long.
     */
    fun canBePromoted(minSizeToBePromoted: ULong): MemcpyArgs? {
        val name = "canBePromoted"
        check(loads.size == stores.size) {
            "$name expects same number of loads and stores: $loads and $stores"
        }
        check(loadLocInsts.size == storeLocInsts.size) {
            "$name expects same number of loads and stores: $loadLocInsts and $storeLocInsts"
        }
        check(loads.size == loadLocInsts.size) {
            "$name: $loads and $loadLocInsts should have same size"
        }
        check(stores.size == storeLocInsts.size) {
            "$name: $stores and $storeLocInsts should have same size"
        }

        // Ensure that no overlaps between source and destination
        fun noOverlaps(srcRegion: MemAccessRegion, srcStart: Long,
                       dstRegion: MemAccessRegion, dstStart: Long,
                       size: ULong): Boolean {
            return when {
                srcRegion == MemAccessRegion.STACK && dstRegion == MemAccessRegion.STACK -> {
                    // Both ranges live on the stack, so their byte ranges can be compared directly.
                    val srcInterval = FiniteInterval.mkInterval(srcStart, size.toLong())
                    val dstInterval = FiniteInterval.mkInterval(dstStart, size.toLong())
                    !srcInterval.overlap(dstInterval)
                }
                srcRegion != dstRegion && srcRegion != MemAccessRegion.ANY && dstRegion != MemAccessRegion.ANY ->
                    // Two distinct, fully-known regions cannot alias each other.
                    true
                else -> {
                    // Aliasing cannot be disproved; promote only under the optimistic flag.
                    if (SolanaConfig.optimisticMemcpyPromotion()) {
                        dbg {
                            "$name: we cannot prove that no overlaps between $loadLocInsts and $storeLocInsts"
                        }
                        true
                    } else {
                        false
                    }
                }
            }
        }

        /**
         * Find a single interval for all loads and another single interval for all stores.
         * If it cannot then it removes the last inserted load and store and try again.
         * This is a greedy approach, so it's not optimal.
         */
        fun getRangeForLoadsAndStores(): Pair<FiniteInterval, FiniteInterval>? {
            while (loads.isNotEmpty()) { // this loop is needed by test24
                var srcIntervals = SetOfFiniteIntervals.new()
                var dstIntervals = SetOfFiniteIntervals.new()
                var prevLoad: MemAccess? = null
                var prevStore: MemAccess? = null
                for ((load, store) in loads.zip(stores)) {
                    if (prevLoad != null && prevStore != null) {
                        // Signed distances: loads and stores must advance in the same direction.
                        val loadDist = load.offset - prevLoad.offset
                        val storeDist = store.offset - prevStore.offset
                        if (loadDist != storeDist) {
                            // loads and stores have different ordering, so it's not a memcpy
                            // Note that we completely give up here even if we could also try smaller
                            // number of pairs of load-store.
                            return null
                        }
                    }
                    srcIntervals = srcIntervals.add(FiniteInterval.mkInterval(load.offset, load.width.toLong()))
                    dstIntervals = dstIntervals.add(FiniteInterval.mkInterval(store.offset, store.width.toLong()))

                    prevLoad = load
                    prevStore = store
                }
                // A singleton interval means the accesses cover a consecutive byte range with no holes.
                val srcSingleton = srcIntervals.getSingleton()
                val dstSingleton = dstIntervals.getSingleton()
                if (srcSingleton != null && dstSingleton != null) {
                    return srcSingleton to dstSingleton
                }
                // Before we return null we remove the last inserted pair and try again
                removeLastPair()
            }
            return null
        }

        if (loads.isEmpty()) {
            return null
        }
        val (srcRange, dstRange) = getRangeForLoadsAndStores() ?: return null
        return if (srcRange.size() == dstRange.size() && srcRange.size() >= minSizeToBePromoted) {
            // Class invariants #2 and #3 guarantee the first entry is representative of all entries.
            val srcReg = loads.first().reg
            val dstReg = stores.first().reg
            val srcRegion = loads.first().region
            val dstRegion = stores.first().region
            // We will use the metadata of the first load as metadata of the promoted memcpy
            val metadata = loadLocInsts.first().inst.metaData
            if (noOverlaps(srcRegion, srcRange.l, dstRegion, dstRange.l, srcRange.size())) {
                MemcpyArgs(srcReg, srcRange.l, dstReg, dstRange.l, srcRange.size(), metadata)
            } else {
                null
            }
        } else {
            null
        }
    }
}
0 commit comments