Skip to content

Commit e9cea1e

Browse files
committed
refactor(solana): split PromoteStoresToMemcpy into multiple files (#8336)
b923324d63bc0fdefeff0007c11ca3b7047abc49
1 parent 3ce25ba commit e9cea1e

File tree

12 files changed

+1327
-1239
lines changed

12 files changed

+1327
-1239
lines changed

src/main/kotlin/sbf/cfg/PointerOptimizations.kt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ fun runSimplePTAOptimizations(cfg: MutableSbfCFG, globals: GlobalVariables) {
5858
fun runPTAOptimizations(prog: SbfCallGraph, memSummaries: MemorySummaries): SbfCallGraph {
5959
return prog.transformSingleEntry { entryCFG ->
6060
val optEntryCFG = entryCFG.clone(entryCFG.getName())
61-
promoteStoresToMemcpy(optEntryCFG, prog.getGlobals(), memSummaries)
61+
promoteMemcpy(optEntryCFG, prog.getGlobals(), memSummaries)
6262
removeUselessDefinitions(optEntryCFG)
6363
promoteMemset(optEntryCFG, prog.getGlobals(), memSummaries)
6464
markLoadedAsNumForPTA(optEntryCFG)
Lines changed: 172 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,172 @@
1+
/*
2+
* The Certora Prover
3+
* Copyright (C) 2025 Certora Ltd.
4+
*
5+
* This program is free software: you can redistribute it and/or modify
6+
* it under the terms of the GNU General Public License as published by
7+
* the Free Software Foundation, version 3 of the License.
8+
*
9+
* This program is distributed in the hope that it will be useful,
10+
* but WITHOUT ANY WARRANTY; without even the implied warranty of
11+
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12+
* GNU General Public License for more details.
13+
*
14+
* You should have received a copy of the GNU General Public License
15+
* along with this program. If not, see <https://www.gnu.org/licenses/>.
16+
*/
17+
18+
package sbf.cfg
19+
20+
import datastructures.stdcollections.*
21+
22+
/**
 * Pick the memcpy flavour that matches the widths of a load/store pair and emit its code.
 *
 * - Equal widths: delegates to [emitMemcpy] copying [srcWidth] bytes.
 * - 8-byte source and narrower destination: delegates to [emitMemcpyTrunc] with [dstWidth].
 * - 8-byte destination and narrower source: delegates to [emitMemcpyZExt] with [srcWidth].
 *
 * @param srcWidth width in bytes of the source access
 * @param dstWidth width in bytes of the destination access
 * @param srcMemAccess source access; only its register and offset are used
 * @param dstMemAccess destination access; only its register and offset are used
 * @param metaData metadata forwarded to the emitter
 * @return the emitted instructions, or `null` if no variant fits the width combination
 */
fun emitMemcpyVariant(
    srcWidth: ULong,
    dstWidth: ULong,
    srcMemAccess: MemAccess,
    dstMemAccess: MemAccess,
    metaData: MetaData
): List<SbfInstruction>? {
    val srcReg = srcMemAccess.reg
    val srcOffset = srcMemAccess.offset
    val dstReg = dstMemAccess.reg
    val dstOffset = dstMemAccess.offset

    if (srcWidth == dstWidth) {
        return emitMemcpy(srcReg, srcOffset, dstReg, dstOffset, srcWidth, metaData)
    }
    if (srcWidth == 8UL && dstWidth < 8UL) {
        return emitMemcpyTrunc(srcReg, srcOffset, dstReg, dstOffset, dstWidth, metaData)
    }
    if (dstWidth == 8UL && srcWidth < 8UL) {
        return emitMemcpyZExt(srcReg, srcOffset, dstReg, dstOffset, srcWidth, metaData)
    }
    // Any other width combination is not supported.
    return null
}
37+
38+
/**
 * A pending replacement of a group of loads and stores by new code.
 *
 * @property loads The original loads
 * @property stores The original stores
 * @property insts The new code that will replace [loads] and [stores]
 *
 * [loads] must be in the same basic block.
 * [stores] must be in the same basic block.
 * However [loads] and [stores] can be in different blocks.
 *
 * Construction throws [IllegalStateException] if [loads] or [stores] is empty, if they differ in
 * size, or if either list spans more than one basic block.
 */
data class MemcpyRewrite(
    val loads: List<LocatedSbfInstruction>,
    val stores: List<LocatedSbfInstruction>,
    val insts: List<SbfInstruction>
) {
    init {
        check(loads.isNotEmpty()) {
            "A memcpy rewrite needs at least one load"
        }
        check(stores.isNotEmpty()) {
            "A memcpy rewrite needs at least one store"
        }
        check(loads.size == stores.size) {
            "Loads and stores must be paired one-to-one, but got ${loads.size} loads and ${stores.size} stores"
        }
        // All loads must be in the same block
        check(loads.map { it.label }.distinct().size == 1) {
            "All loads must be in the same basic block"
        }
        // All stores must be in the same block
        check(stores.map { it.label }.distinct().size == 1) {
            "All stores must be in the same basic block"
        }
    }
}
66+
67+
/**
 * Rewrite [cfg] in place according to [rewrites]. Does nothing if [rewrites] is empty.
 *
 * The work is done in three passes:
 *  1. every load and store mentioned by a rewrite is tagged with `MEMCPY_PROMOTION` metadata;
 *  2. the new instructions of each rewrite are inserted right before its first load;
 *  3. tagged stores are deleted from the blocks that contain the rewritten stores.
 *
 * Stores are removed but loads are left intact because they can be used by other instructions.
 * Subsequent optimizations are expected to remove the load instructions if they are dead.
 *
 * Recall that all loads (stores) of a rewrite must be in the same basic block, **but** loads and
 * stores can be in different blocks.
 **/
fun applyRewrites(cfg: MutableSbfCFG, rewrites: List<MemcpyRewrite>) {
    val ordered = sortRewrites(rewrites)
    if (ordered.isEmpty()) {
        return
    }

    // Pass 1: tag all loads and stores. Positions are still the original ones at this point
    // because nothing has been inserted or removed yet.
    fun annotate(group: List<LocatedSbfInstruction>) {
        val block = checkNotNull(cfg.getMutableBlock(group.first().label))
        group.forEach { addMemcpyPromotionAnnotation(block, it) }
    }
    ordered.forEach { rewrite ->
        annotate(rewrite.loads)
        annotate(rewrite.stores)
    }

    // Pass 2: insert the new code block by block, keyed by the block holding the loads.
    // The memcpy instructions must go before the first load.
    // For an explanation, see test13 in PromoteStoresToMemcpyTest.kt
    for ((label, blockRewrites) in ordered.groupBy { it.loads.first().label }) {
        val block = checkNotNull(cfg.getMutableBlock(label))
        // Number of instructions already inserted into this block; since the rewrites are sorted
        // by first-load position, every earlier insertion shifts the later insertion points.
        var shift = 0
        for (rewrite in blockRewrites) {
            val insertAt = rewrite.loads.minOf { it.pos } + shift
            shift += rewrite.insts.size
            block.addAll(insertAt, rewrite.insts)
        }
    }

    // Pass 3: delete the stores tagged with `MEMCPY_PROMOTION`. Only blocks that contain stores
    // of some rewrite are scanned, and they are re-scanned here so that positions reflect the
    // insertions done in pass 2.
    val storeLabels = ordered.map { it.stores.first().label }.toSet()
    for (label in storeLabels) {
        val block = checkNotNull(cfg.getMutableBlock(label))
        val doomed = block.getLocatedInstructions().filterTo(ArrayList<LocatedSbfInstruction>()) { locInst ->
            val inst = locInst.inst
            inst is SbfInstruction.Mem && !inst.isLoad &&
                inst.metaData.getVal(SbfMeta.MEMCPY_PROMOTION) != null
        }

        // Removing front to back: each earlier removal moves the remaining targets one slot left.
        doomed.forEachIndexed { alreadyRemoved, locInst ->
            val currentPos = locInst.pos - alreadyRemoved
            val inst = block.getInstruction(currentPos)
            check(inst is SbfInstruction.Mem && !inst.isLoad) {
                "applyRewrites expects a store instruction"
            }
            block.removeAt(currentPos)
        }
    }
}
141+
142+
/**
 * Order [rewrites] so that, within each basic block, they appear by increasing position of
 * their earliest load.
 *
 * Rewrites are first grouped by the block that contains their loads (groups keep the order in
 * which their block first shows up in [rewrites]) and each group is then sorted. Having the
 * rewrites of a block in this order lets the caller adjust insertion points with a simple
 * running offset while it inserts the emitted code.
 */
private fun sortRewrites(rewrites: List<MemcpyRewrite>): List<MemcpyRewrite> =
    rewrites
        .groupBy { it.loads.first().label }
        .values
        .flatMap { sameBlock ->
            sameBlock.sortedBy { rewrite -> rewrite.loads.minOf { it.pos } }
        }
164+
165+
/**
 * Replace the instruction at `locInst.pos` in [bb] with a copy whose metadata additionally
 * carries `MEMCPY_PROMOTION`.
 *
 * Instructions that are not [SbfInstruction.Mem] are left untouched.
 */
private fun addMemcpyPromotionAnnotation(bb: MutableSbfBasicBlock, locInst: LocatedSbfInstruction) {
    val memInst = locInst.inst as? SbfInstruction.Mem ?: return
    val taggedMetaData = memInst.metaData + SbfMeta.MEMCPY_PROMOTION()
    bb.replaceInstruction(locInst.pos, memInst.copy(metaData = taggedMetaData))
}
Lines changed: 211 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,211 @@
1+
/*
2+
* The Certora Prover
3+
* Copyright (C) 2025 Certora Ltd.
4+
*
5+
* This program is free software: you can redistribute it and/or modify
6+
* it under the terms of the GNU General Public License as published by
7+
* the Free Software Foundation, version 3 of the License.
8+
*
9+
* This program is distributed in the hope that it will be useful,
10+
* but WITHOUT ANY WARRANTY; without even the implied warranty of
11+
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12+
* GNU General Public License for more details.
13+
*
14+
* You should have received a copy of the GNU General Public License
15+
* along with this program. If not, see <https://www.gnu.org/licenses/>.
16+
*/
17+
18+
package sbf.cfg
19+
20+
import log.Logger
21+
import log.LoggerTypes
22+
import sbf.SolanaConfig
23+
import sbf.disassembler.SbfRegister
24+
import sbf.domains.FiniteInterval
25+
import sbf.domains.SetOfFiniteIntervals
26+
import kotlin.math.absoluteValue
27+
28+
/** Logger used by the memcpy-promotion code in this file. */
private val logger = Logger(LoggerTypes.SBF_MEMCPY_PROMOTION)

/** Forward the lazily-built message [msg] to [logger] via its `debug` method. */
private fun dbg(msg: () -> Any) {
    logger.debug(msg)
}
30+
31+
/**
 * A candidate memcpy, i.e., a sequence of load/store pairs collected one pair at a time
 * with [add] and validated as a whole with [canBePromoted].
 **/
class MemcpyPattern {
    /*
     * Class invariants:
     * 0. `loads.size == stores.size`
     * 1. `forall 0 <= i < size-1 :: distance(loads[i], loads[i+1]) == distance(stores[i], stores[i+1])`
     * 2. `forall i,j, i!=j :: loads[i].region == loads[j].region && stores[i].region == stores[j].region`
     * 3. `forall i,j, i!=j :: loads[i].reg == loads[j].reg && stores[i].reg == stores[j].reg`
     *
     * The four lists below are parallel: index `i` describes the i-th load/store pair.
     */
    private val loads = mutableListOf<MemAccess>()
    private val stores = mutableListOf<MemAccess>()
    private val loadLocInsts = mutableListOf<LocatedSbfInstruction>()
    private val storeLocInsts = mutableListOf<LocatedSbfInstruction>()

    /** The store instructions currently part of the pattern, in insertion order. */
    fun getStores(): List<LocatedSbfInstruction> = storeLocInsts

    /** The load instructions currently part of the pattern, in insertion order. */
    fun getLoads(): List<LocatedSbfInstruction> = loadLocInsts

    /**
     * Try to extend the pattern with the [load]/[store] pair.
     *
     * The first pair is always accepted. A later pair is accepted only if, compared with the
     * most recently added pair, it uses the same registers (invariant #3), accesses the same
     * regions (invariant #2), and the absolute offset distance between the loads equals the
     * absolute offset distance between the stores (invariant #1).
     *
     * @return `true` if the pair was added; `false` if adding it would break a class invariant,
     * in which case the pattern is left unchanged.
     * @throws IllegalStateException if [loadLocInst] and [storeLocInst] are in different blocks
     **/
    fun add(
        load: MemAccess,
        loadLocInst: LocatedSbfInstruction,
        store: MemAccess,
        storeLocInst: LocatedSbfInstruction
    ): Boolean {
        check(loadLocInst.label == storeLocInst.label) {
            "can only promote pairs of load-store within the same block"
        }

        if (loads.isNotEmpty()) {
            val prevLoad = loads.last()
            val prevStore = stores.last()
            // class invariant #3
            if (prevLoad.reg != load.reg || prevStore.reg != store.reg) {
                return false
            }
            // class invariant #2
            if (prevLoad.region != load.region || prevStore.region != store.region) {
                return false
            }
            // class invariant #1 (direction is ignored here; it is checked in canBePromoted)
            val loadGap = (load.offset - prevLoad.offset).absoluteValue
            val storeGap = (store.offset - prevStore.offset).absoluteValue
            if (loadGap != storeGap) {
                return false
            }
        }

        loads.add(load)
        stores.add(store)
        loadLocInsts.add(loadLocInst)
        storeLocInsts.add(storeLocInst)
        return true
    }

    /**
     * Arguments of the memcpy that replaces the pattern: copy [size] bytes from
     * [srcReg] + [srcStart] to [dstReg] + [dstStart], tagged with [metadata].
     */
    data class MemcpyArgs(
        val srcReg: SbfRegister,
        val srcStart: Long,
        val dstReg: SbfRegister,
        val dstStart: Long,
        val size: ULong,
        val metadata: MetaData
    )

    /**
     * Return non-null if
     * (1) source and destination do not overlap (or this cannot be decided and
     *     `SolanaConfig.optimisticMemcpyPromotion()` is enabled) and
     * (2) the sequence of loads and stores accesses memory in the same ordering
     *     (decreasing or increasing) and
     * (3) the sequences form a consecutive range of bytes without holes in between and
     * (4) source and destination ranges have the same size, which is at least
     *     [minSizeToBePromoted].
     *
     * Note that this method can shrink the pattern: while looking for hole-free ranges it drops
     * the most recently added pairs, so [getLoads] and [getStores] may return fewer instructions
     * afterwards.
     */
    fun canBePromoted(minSizeToBePromoted: ULong): MemcpyArgs? {
        val name = "canBePromoted"
        check(loads.size == stores.size) {
            "$name expects same number of loads and stores: $loads and $stores"
        }
        check(loadLocInsts.size == storeLocInsts.size) {
            "$name expects same number of loads and stores: $loadLocInsts and $storeLocInsts"
        }
        check(loads.size == loadLocInsts.size) {
            "$name: $loads and $loadLocInsts should have same size"
        }
        check(stores.size == storeLocInsts.size) {
            "$name: $stores and $storeLocInsts should have same size"
        }

        // Decide whether [srcStart, srcStart+size) and [dstStart, dstStart+size) are disjoint.
        fun noOverlaps(
            srcRegion: MemAccessRegion,
            srcStart: Long,
            dstRegion: MemAccessRegion,
            dstStart: Long,
            size: ULong
        ): Boolean = when {
            // Both on the stack: compare the two byte ranges directly.
            srcRegion == MemAccessRegion.STACK && dstRegion == MemAccessRegion.STACK -> {
                val srcBytes = FiniteInterval.mkInterval(srcStart, size.toLong())
                val dstBytes = FiniteInterval.mkInterval(dstStart, size.toLong())
                !srcBytes.overlap(dstBytes)
            }
            // Two different, known regions are treated as disjoint.
            srcRegion != dstRegion &&
                srcRegion != MemAccessRegion.ANY &&
                dstRegion != MemAccessRegion.ANY -> true
            // Undecidable: accept only if the user asked for optimistic promotion.
            SolanaConfig.optimisticMemcpyPromotion() -> {
                dbg {
                    "$name: we cannot prove that no overlaps between $loadLocInsts and $storeLocInsts"
                }
                true
            }
            else -> false
        }

        /**
         * Find a single interval for all loads and another single interval for all stores.
         * If it cannot then it removes the last inserted load and store and tries again.
         * This is a greedy approach, so it's not optimal.
         */
        fun getRangeForLoadsAndStores(): Pair<FiniteInterval, FiniteInterval>? {
            while (loads.isNotEmpty()) { // this loop is needed by test24
                var srcIntervals = SetOfFiniteIntervals.new()
                var dstIntervals = SetOfFiniteIntervals.new()
                for (i in loads.indices) {
                    val load = loads[i]
                    val store = stores[i]
                    if (i > 0) {
                        val loadDist = load.offset - loads[i - 1].offset
                        val storeDist = store.offset - stores[i - 1].offset
                        if (loadDist != storeDist) {
                            // loads and stores have different ordering, so it's not a memcpy
                            // Note that we completely give up here even if we could also try
                            // smaller number of pairs of load-store.
                            return null
                        }
                    }
                    srcIntervals = srcIntervals.add(FiniteInterval.mkInterval(load.offset, load.width.toLong()))
                    dstIntervals = dstIntervals.add(FiniteInterval.mkInterval(store.offset, store.width.toLong()))
                }

                val srcSingleton = srcIntervals.getSingleton()
                val dstSingleton = dstIntervals.getSingleton()
                if (srcSingleton != null && dstSingleton != null) {
                    return srcSingleton to dstSingleton
                }

                // No single interval on one of the sides: drop the last inserted pair and retry
                loads.removeLast()
                stores.removeLast()
                loadLocInsts.removeLast()
                storeLocInsts.removeLast()
            }
            return null
        }

        if (loads.isEmpty()) {
            return null
        }
        val (srcRange, dstRange) = getRangeForLoadsAndStores() ?: return null
        if (srcRange.size() != dstRange.size() || srcRange.size() < minSizeToBePromoted) {
            return null
        }

        val srcReg = loads.first().reg
        val dstReg = stores.first().reg
        val srcRegion = loads.first().region
        val dstRegion = stores.first().region
        // We will use the metadata of the first load as metadata of the promoted memcpy
        val metadata = loadLocInsts.first().inst.metaData
        return if (noOverlaps(srcRegion, srcRange.l, dstRegion, dstRange.l, srcRange.size())) {
            MemcpyArgs(srcReg, srcRange.l, dstReg, dstRange.l, srcRange.size(), metadata)
        } else {
            null
        }
    }
}

0 commit comments

Comments
 (0)