11package org.usvm.utils
22
33import ai.onnxruntime.OnnxTensor
4+ import ai.onnxruntime.OnnxTensor.createTensor
45import ai.onnxruntime.OrtEnvironment
56import ai.onnxruntime.OrtSession
67import org.usvm.StateId
@@ -10,16 +11,38 @@ import org.usvm.statistics.BlockGraph
1011import org.usvm.util.OnnxModel
1112import java.nio.FloatBuffer
1213import java.nio.LongBuffer
14+ import kotlin.collections.toLongArray
1315
1416enum class Mode {
15- CPU ,
16- GPU
17+ CPU , GPU
1718}
1819
19- class OnnxModelImpl <Block : BasicBlock >(
20- pathToONNX : String ,
21- mode : Mode
22- ): OnnxModel<Game<Block>> {
20+ data class AIGameStep <T : BasicBlock >(
21+ val Game : Game <T >, val NNOutput : Array <FloatArray >
22+ ) {
23+ // autogenerated
24+ override fun equals (other : Any? ): Boolean {
25+ if (this == = other) return true
26+ if (javaClass != other?.javaClass) return false
27+
28+ other as AIGameStep <* >
29+
30+ if (Game != other.Game ) return false
31+ if (! NNOutput .contentDeepEquals(other.NNOutput )) return false
32+
33+ return true
34+ }
35+
36+ override fun hashCode (): Int {
37+ var result = Game .hashCode()
38+ result = 31 * result + NNOutput .contentDeepHashCode()
39+ return result
40+ }
41+ }
42+
43+ class OnnxModelImpl <Block : BasicBlock >(
44+ pathToONNX : String , mode : Mode , val isTrainMode : Boolean , val saveStep : (AIGameStep <Block >) -> Unit = {}
45+ ) : OnnxModel<Game<Block>> {
2346 private val env: OrtEnvironment = OrtEnvironment .getEnvironment()
2447 private val sessionOptions: OrtSession .SessionOptions = OrtSession .SessionOptions ().apply {
2548 if (mode == Mode .GPU ) {
@@ -34,48 +57,47 @@ class OnnxModelImpl<Block: BasicBlock>(
3457 val stateIds = mutableMapOf<StateId , Int >()
3558 val input = generateInput(game, stateIds)
3659 val output = session.run (input)
37-
38- val predictedStatesRanks =
39- (output[" out" ].get().value as Array <* >).map { (it as FloatArray ).toList() }
40-
60+ val predictedStates = output[" out" ].get().value as Array <* >
61+ val predictedStatesRanks = predictedStates.map { (it as FloatArray ).toList() }
62+ saveStep(AIGameStep (game, predictedStates.map { it as FloatArray }.toTypedArray()))
4163 return checkNotNull(getPredictedState(predictedStatesRanks, stateIds))
4264 }
4365
4466 private fun generateInput (
45- game : Game <Block >,
46- stateIds : MutableMap <StateId , Int >
67+ game : Game <Block >, stateIds : MutableMap <StateId , Int >
4768 ): Map <String , OnnxTensor > {
4869 val (vertices, stateWrappers, blockGraph) = game
4970 val vertexIds = mutableMapOf<Int , Int >()
5071 val gameVertices = tensorFromBasicBlocks(vertices, vertexIds)
5172 val states = tensorFromStates(stateWrappers, stateIds)
5273 val (vertexToVertexEdgesIndex, vertexToVertexEdgesAttributes) = tensorFromVertexEdges(
53- vertices,
54- blockGraph,
55- vertexIds
74+ vertices, blockGraph, vertexIds
5675 )
5776 val (parentOfEdges, historyEdgesIndexVertexToState, historyEdgesAttributes) = tensorFromStateEdges(
58- stateWrappers,
59- stateIds,
60- vertexIds
77+ stateWrappers, stateIds, vertexIds
6178 )
6279 val vertexToState = tensorFromStatePositions(stateWrappers, stateIds, vertexIds)
63-
80+ val mockPC1 = createTensor(env, FloatBuffer .wrap(FloatArray (0 )), longArrayOf(0 , 49 ))
81+ val mockPC2 = createTensor(env, LongBuffer .wrap(LongArray (0 )), longArrayOf(2 , 0 ))
82+ val mockPC3 =
83+ createTensor(env, LongBuffer .wrap(LongArray (0 )), longArrayOf(2 , 0 ))
6484 return mapOf (
6585 " game_vertex" to gameVertices,
6686 " state_vertex" to states,
87+ " path_condition_vertex" to mockPC1,
6788 " gamevertex_to_gamevertex_index" to vertexToVertexEdgesIndex,
6889 " gamevertex_to_gamevertex_type" to vertexToVertexEdgesAttributes,
6990 " gamevertex_history_statevertex_index" to historyEdgesIndexVertexToState,
7091 " gamevertex_history_statevertex_attrs" to historyEdgesAttributes,
7192 " gamevertex_in_statevertex" to vertexToState,
72- " statevertex_parentof_statevertex" to parentOfEdges
93+ " statevertex_parentof_statevertex" to parentOfEdges,
94+ " pathconditionvertex_to_pathconditionvertex" to mockPC2,
95+ " pathconditionvertex_to_statevertex" to mockPC3,
7396 )
7497 }
7598
7699 private fun getPredictedState (stateRank : List <List <Float >>, stateIds : Map <StateId , Int >): StateId ? {
77- return stateRank
78- .mapIndexed { index, ranks -> stateIds.entries.find { it.value == index }?.key to ranks.sum() }
100+ return stateRank.mapIndexed { index, ranks -> stateIds.entries.find { it.value == index }?.key to ranks.sum() }
79101 .maxBy { it.second }.first
80102 }
81103
@@ -97,13 +119,14 @@ class OnnxModelImpl<Block: BasicBlock>(
97119 return createFloatTensor(verticesArray, vertices.size, numOfVertexAttributes)
98120 }
99121
100- private fun tensorFromStates (states : Collection <StateWrapper <* , * , * >>, stateIds : MutableMap <StateId , Int >): OnnxTensor {
101- val numOfStateAttributes = 7
122+ private fun tensorFromStates (
123+ states : Collection <StateWrapper <* , * , * >>, stateIds : MutableMap <StateId , Int >
124+ ): OnnxTensor {
125+ val numOfStateAttributes = 6
102126 val statesArray = states.flatMapIndexed { i, state ->
103127 stateIds[state.id] = i
104128 listOf (
105129 state.position,
106- state.pathConditionSize,
107130 state.visitedAgainVertices,
108131 state.visitedNotCoveredVerticesInZone,
109132 state.visitedNotCoveredVerticesOutOfZone,
@@ -116,9 +139,7 @@ class OnnxModelImpl<Block: BasicBlock>(
116139 }
117140
118141 private fun tensorFromVertexEdges (
119- blocks : Collection <Block >,
120- blockGraph : BlockGraph <* , Block , * >,
121- vertexIds : MutableMap <Int , Int >
142+ blocks : Collection <Block >, blockGraph : BlockGraph <* , Block , * >, vertexIds : MutableMap <Int , Int >
122143 ): Pair <OnnxTensor , OnnxTensor > {
123144 val vertexFrom = mutableListOf<Long >()
124145 val vertexTo = mutableListOf<Long >()
@@ -136,15 +157,12 @@ class OnnxModelImpl<Block: BasicBlock>(
136157 val indexList = (vertexFrom + vertexTo).toLongArray()
137158
138159 return createLongTensor(indexList, 2 , vertexFrom.size) to createLongTensor(
139- attributes.toLongArray(),
140- attributes.size
160+ attributes.toLongArray(), attributes.size
141161 )
142162 }
143163
144164 private fun tensorFromStateEdges (
145- states : Collection <StateWrapper <* , * , * >>,
146- stateIds : MutableMap <StateId , Int >,
147- vertexIds : MutableMap <Int , Int >
165+ states : Collection <StateWrapper <* , * , * >>, stateIds : MutableMap <StateId , Int >, vertexIds : MutableMap <Int , Int >
148166 ): Triple <OnnxTensor , OnnxTensor , OnnxTensor > {
149167 val numOfParentOfEdges = states.sumOf { it.children.size }
150168 val numOfHistoryEdges = states.sumOf { it.history.size }
@@ -173,20 +191,15 @@ class OnnxModelImpl<Block: BasicBlock>(
173191 }
174192
175193 val parentTensor = createLongTensor(parentOf, 2 , numOfParentOfEdges)
176- val historyIndexTensor =
177- createLongTensor(historyIndexVertexToState, 2 , numOfHistoryEdges)
194+ val historyIndexTensor = createLongTensor(historyIndexVertexToState, 2 , numOfHistoryEdges)
178195 val historyAttributesTensor = createLongTensor(
179- historyAttributes,
180- numOfHistoryEdges,
181- numOfHistoryEdgeAttributes
196+ historyAttributes, numOfHistoryEdges, numOfHistoryEdgeAttributes
182197 )
183198 return Triple (parentTensor, historyIndexTensor, historyAttributesTensor)
184199 }
185200
186201 private fun tensorFromStatePositions (
187- states : Collection <StateWrapper <* , * , * >>,
188- stateIds : MutableMap <StateId , Int >,
189- vertexIds : MutableMap <Int , Int >
202+ states : Collection <StateWrapper <* , * , * >>, stateIds : MutableMap <StateId , Int >, vertexIds : MutableMap <Int , Int >
190203 ): OnnxTensor {
191204 val totalStates = states.size
192205 val vertexToState = LongArray (2 * totalStates)
0 commit comments