Skip to content

Commit 38f73b5

Browse files
committed
Add model validation
1 parent 758b8fd commit 38f73b5

File tree

11 files changed

+191
-93
lines changed

11 files changed

+191
-93
lines changed

usvm-core/src/main/kotlin/org/usvm/ps/AIPathSelector.kt

Lines changed: 12 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
11
package org.usvm.ps
2+
23
import org.usvm.UPathSelector
34
import org.usvm.UState
45
import org.usvm.statistics.BasicBlock
56
import org.usvm.statistics.BlockGraph
67
import org.usvm.statistics.StepsStatistics
7-
import org.usvm.util.OnnxModel
8-
import org.usvm.util.Oracle
98
import org.usvm.util.Predictor
109
import org.usvm.utils.Game
10+
import org.usvm.utils.OnnxModelImpl
1111
import org.usvm.utils.StateWrapper
1212
import org.usvm.utils.isSat
1313

@@ -17,8 +17,7 @@ class AIPathSelector<Statement, State, Block>(
1717
private val stepsStatistics: StepsStatistics<*, State>,
1818
private val predictor: Predictor<Game<Block>>,
1919
) : UPathSelector<State> where
20-
State : UState<*, *, Statement, *, *, State>,
21-
Block : BasicBlock {
20+
State : UState<*, *, Statement, *, *, State>, Block : BasicBlock {
2221
private val statesMap = mutableMapOf<State, StateWrapper<Statement, State, Block>>()
2322
private val lastPeekedState: State?
2423
get() = stepsStatistics.lastPeekedState
@@ -43,12 +42,14 @@ Block : BasicBlock {
4342
}
4443

4544
private fun buildGame(
46-
vertices: List<Block>,
47-
wrappers: MutableCollection<StateWrapper<Statement, State, Block>>
45+
vertices: List<Block>, wrappers: MutableCollection<StateWrapper<Statement, State, Block>>
4846
): Game<Block> {
49-
val game = when (predictor) {
50-
is OnnxModel<Game<Block>> -> Game(vertices, wrappers, blockGraph)
51-
is Oracle<Game<Block>> -> {
47+
val game = when {
48+
predictor is OnnxModelImpl<*> && !predictor.isTrainMode -> {
49+
Game(vertices, wrappers, blockGraph)
50+
}
51+
52+
else -> {
5253
// if we played with default searcher before,
5354
// client has no information about the game
5455
if (firstSend) {
@@ -110,8 +111,7 @@ Block : BasicBlock {
110111
* If state [isSat], some blocks may have been covered by
111112
* [org.usvm.statistics.CoverageStatistics.onStateTerminated]
112113
*/
113-
if (state.isSat())
114-
touchedBlocks.addAll(wrapper.history.keys)
114+
if (state.isSat()) touchedBlocks.addAll(wrapper.history.keys)
115115

116116
touchedStates.remove(wrapper)
117117
}
@@ -120,15 +120,9 @@ Block : BasicBlock {
120120
override fun add(states: Collection<State>) {
121121
// is null iff we are adding initial state
122122
val lastPeekedStateWrapper = statesMap[lastPeekedState]
123-
val parentPathConditionSize = lastPeekedStateWrapper?.pathConditionSize ?: 0
124123
val parentHistory = lastPeekedStateWrapper?.history ?: mutableMapOf()
125124
val wrappers = states.map { state ->
126-
val wrapper = StateWrapper(
127-
state,
128-
parentPathConditionSize,
129-
parentHistory,
130-
blockGraph
131-
)
125+
val wrapper = StateWrapper(state, parentHistory, blockGraph)
132126
statesMap[state] = wrapper
133127
wrapper
134128
}

usvm-core/src/main/kotlin/org/usvm/utils/OnnxModelImpl.kt

Lines changed: 54 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package org.usvm.utils
22

33
import ai.onnxruntime.OnnxTensor
4+
import ai.onnxruntime.OnnxTensor.createTensor
45
import ai.onnxruntime.OrtEnvironment
56
import ai.onnxruntime.OrtSession
67
import org.usvm.StateId
@@ -10,16 +11,38 @@ import org.usvm.statistics.BlockGraph
1011
import org.usvm.util.OnnxModel
1112
import java.nio.FloatBuffer
1213
import java.nio.LongBuffer
14+
import kotlin.collections.toLongArray
1315

1416
enum class Mode {
15-
CPU,
16-
GPU
17+
CPU, GPU
1718
}
1819

19-
class OnnxModelImpl<Block: BasicBlock>(
20-
pathToONNX: String,
21-
mode: Mode
22-
): OnnxModel<Game<Block>> {
20+
data class AIGameStep<T : BasicBlock>(
21+
val Game: Game<T>, val NNOutput: Array<FloatArray>
22+
) {
23+
// autogenerated
24+
override fun equals(other: Any?): Boolean {
25+
if (this === other) return true
26+
if (javaClass != other?.javaClass) return false
27+
28+
other as AIGameStep<*>
29+
30+
if (Game != other.Game) return false
31+
if (!NNOutput.contentDeepEquals(other.NNOutput)) return false
32+
33+
return true
34+
}
35+
36+
override fun hashCode(): Int {
37+
var result = Game.hashCode()
38+
result = 31 * result + NNOutput.contentDeepHashCode()
39+
return result
40+
}
41+
}
42+
43+
class OnnxModelImpl<Block : BasicBlock>(
44+
pathToONNX: String, mode: Mode, val isTrainMode: Boolean, val saveStep: (AIGameStep<Block>) -> Unit = {}
45+
) : OnnxModel<Game<Block>> {
2346
private val env: OrtEnvironment = OrtEnvironment.getEnvironment()
2447
private val sessionOptions: OrtSession.SessionOptions = OrtSession.SessionOptions().apply {
2548
if (mode == Mode.GPU) {
@@ -34,48 +57,47 @@ class OnnxModelImpl<Block: BasicBlock>(
3457
val stateIds = mutableMapOf<StateId, Int>()
3558
val input = generateInput(game, stateIds)
3659
val output = session.run(input)
37-
38-
val predictedStatesRanks =
39-
(output["out"].get().value as Array<*>).map { (it as FloatArray).toList() }
40-
60+
val predictedStates = output["out"].get().value as Array<*>
61+
val predictedStatesRanks = predictedStates.map { (it as FloatArray).toList() }
62+
saveStep(AIGameStep(game, predictedStates.map { it as FloatArray }.toTypedArray()))
4163
return checkNotNull(getPredictedState(predictedStatesRanks, stateIds))
4264
}
4365

4466
private fun generateInput(
45-
game: Game<Block>,
46-
stateIds: MutableMap<StateId, Int>
67+
game: Game<Block>, stateIds: MutableMap<StateId, Int>
4768
): Map<String, OnnxTensor> {
4869
val (vertices, stateWrappers, blockGraph) = game
4970
val vertexIds = mutableMapOf<Int, Int>()
5071
val gameVertices = tensorFromBasicBlocks(vertices, vertexIds)
5172
val states = tensorFromStates(stateWrappers, stateIds)
5273
val (vertexToVertexEdgesIndex, vertexToVertexEdgesAttributes) = tensorFromVertexEdges(
53-
vertices,
54-
blockGraph,
55-
vertexIds
74+
vertices, blockGraph, vertexIds
5675
)
5776
val (parentOfEdges, historyEdgesIndexVertexToState, historyEdgesAttributes) = tensorFromStateEdges(
58-
stateWrappers,
59-
stateIds,
60-
vertexIds
77+
stateWrappers, stateIds, vertexIds
6178
)
6279
val vertexToState = tensorFromStatePositions(stateWrappers, stateIds, vertexIds)
63-
80+
val mockPC1 = createTensor(env, FloatBuffer.wrap(FloatArray(0)), longArrayOf(0, 49))
81+
val mockPC2 = createTensor(env, LongBuffer.wrap(LongArray(0)), longArrayOf(2, 0))
82+
val mockPC3 =
83+
createTensor(env, LongBuffer.wrap(LongArray(0)), longArrayOf(2, 0))
6484
return mapOf(
6585
"game_vertex" to gameVertices,
6686
"state_vertex" to states,
87+
"path_condition_vertex" to mockPC1,
6788
"gamevertex_to_gamevertex_index" to vertexToVertexEdgesIndex,
6889
"gamevertex_to_gamevertex_type" to vertexToVertexEdgesAttributes,
6990
"gamevertex_history_statevertex_index" to historyEdgesIndexVertexToState,
7091
"gamevertex_history_statevertex_attrs" to historyEdgesAttributes,
7192
"gamevertex_in_statevertex" to vertexToState,
72-
"statevertex_parentof_statevertex" to parentOfEdges
93+
"statevertex_parentof_statevertex" to parentOfEdges,
94+
"pathconditionvertex_to_pathconditionvertex" to mockPC2,
95+
"pathconditionvertex_to_statevertex" to mockPC3,
7396
)
7497
}
7598

7699
private fun getPredictedState(stateRank: List<List<Float>>, stateIds: Map<StateId, Int>): StateId? {
77-
return stateRank
78-
.mapIndexed { index, ranks -> stateIds.entries.find { it.value == index }?.key to ranks.sum() }
100+
return stateRank.mapIndexed { index, ranks -> stateIds.entries.find { it.value == index }?.key to ranks.sum() }
79101
.maxBy { it.second }.first
80102
}
81103

@@ -97,13 +119,14 @@ class OnnxModelImpl<Block: BasicBlock>(
97119
return createFloatTensor(verticesArray, vertices.size, numOfVertexAttributes)
98120
}
99121

100-
private fun tensorFromStates(states: Collection<StateWrapper<*, *, *>>, stateIds: MutableMap<StateId, Int>): OnnxTensor {
101-
val numOfStateAttributes = 7
122+
private fun tensorFromStates(
123+
states: Collection<StateWrapper<*, *, *>>, stateIds: MutableMap<StateId, Int>
124+
): OnnxTensor {
125+
val numOfStateAttributes = 6
102126
val statesArray = states.flatMapIndexed { i, state ->
103127
stateIds[state.id] = i
104128
listOf(
105129
state.position,
106-
state.pathConditionSize,
107130
state.visitedAgainVertices,
108131
state.visitedNotCoveredVerticesInZone,
109132
state.visitedNotCoveredVerticesOutOfZone,
@@ -116,9 +139,7 @@ class OnnxModelImpl<Block: BasicBlock>(
116139
}
117140

118141
private fun tensorFromVertexEdges(
119-
blocks: Collection<Block>,
120-
blockGraph: BlockGraph<*, Block, *>,
121-
vertexIds: MutableMap<Int, Int>
142+
blocks: Collection<Block>, blockGraph: BlockGraph<*, Block, *>, vertexIds: MutableMap<Int, Int>
122143
): Pair<OnnxTensor, OnnxTensor> {
123144
val vertexFrom = mutableListOf<Long>()
124145
val vertexTo = mutableListOf<Long>()
@@ -136,15 +157,12 @@ class OnnxModelImpl<Block: BasicBlock>(
136157
val indexList = (vertexFrom + vertexTo).toLongArray()
137158

138159
return createLongTensor(indexList, 2, vertexFrom.size) to createLongTensor(
139-
attributes.toLongArray(),
140-
attributes.size
160+
attributes.toLongArray(), attributes.size
141161
)
142162
}
143163

144164
private fun tensorFromStateEdges(
145-
states: Collection<StateWrapper<*, *, *>>,
146-
stateIds: MutableMap<StateId, Int>,
147-
vertexIds: MutableMap<Int, Int>
165+
states: Collection<StateWrapper<*, *, *>>, stateIds: MutableMap<StateId, Int>, vertexIds: MutableMap<Int, Int>
148166
): Triple<OnnxTensor, OnnxTensor, OnnxTensor> {
149167
val numOfParentOfEdges = states.sumOf { it.children.size }
150168
val numOfHistoryEdges = states.sumOf { it.history.size }
@@ -173,20 +191,15 @@ class OnnxModelImpl<Block: BasicBlock>(
173191
}
174192

175193
val parentTensor = createLongTensor(parentOf, 2, numOfParentOfEdges)
176-
val historyIndexTensor =
177-
createLongTensor(historyIndexVertexToState, 2, numOfHistoryEdges)
194+
val historyIndexTensor = createLongTensor(historyIndexVertexToState, 2, numOfHistoryEdges)
178195
val historyAttributesTensor = createLongTensor(
179-
historyAttributes,
180-
numOfHistoryEdges,
181-
numOfHistoryEdgeAttributes
196+
historyAttributes, numOfHistoryEdges, numOfHistoryEdgeAttributes
182197
)
183198
return Triple(parentTensor, historyIndexTensor, historyAttributesTensor)
184199
}
185200

186201
private fun tensorFromStatePositions(
187-
states: Collection<StateWrapper<*, *, *>>,
188-
stateIds: MutableMap<StateId, Int>,
189-
vertexIds: MutableMap<Int, Int>
202+
states: Collection<StateWrapper<*, *, *>>, stateIds: MutableMap<StateId, Int>, vertexIds: MutableMap<Int, Int>
190203
): OnnxTensor {
191204
val totalStates = states.size
192205
val vertexToState = LongArray(2 * totalStates)

usvm-core/src/main/kotlin/org/usvm/utils/StateWrapper.kt

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@ data class StateHistoryElement(
1313

1414
class StateWrapper<Statement, State, Block>(
1515
private val state: State,
16-
private val parentPathConditionSize: Int,
1716
private val parentHistory: MutableMap<Block, StateHistoryElement>,
1817
private val blockGraph: BlockGraph<*, Block, Statement>,
1918
) where State : UState<*, *, Statement, *, *, State>, Block : BasicBlock {
@@ -24,7 +23,6 @@ class StateWrapper<Statement, State, Block>(
2423
private var visitedStatement: Statement? = null
2524
lateinit var currentBlock: Block
2625
var position by Delegates.notNull<Int>()
27-
var pathConditionSize by Delegates.notNull<Int>()
2826
var visitedAgainVertices by Delegates.notNull<Int>()
2927
var visitedNotCoveredVerticesInZone by Delegates.notNull<Int>()
3028
var visitedNotCoveredVerticesOutOfZone by Delegates.notNull<Int>()
@@ -42,7 +40,6 @@ class StateWrapper<Statement, State, Block>(
4240
currentBlock.states.add(this@StateWrapper.id)
4341

4442
position = currentBlock.id
45-
pathConditionSize = parentPathConditionSize + state.forkPoints.depth
4643
instructionsVisitedInCurrentBlock = 0
4744
}
4845
instructionsVisitedInCurrentBlock++

usvm-ml-gameserver/build.gradle.kts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ plugins {
55
id(Plugins.Ktor)
66
kotlin("plugin.serialization") version Versions.kotlin
77
application
8+
kotlin
89
}
910

1011
application {

0 commit comments

Comments
 (0)