Skip to content

Commit 3220cd5

Browse files
committed
Fix indices in ONNX converter.
1 parent 690daed commit 3220cd5

File tree

1 file changed

+17
-14
lines changed

1 file changed

+17
-14
lines changed

VSharp.Explorer/AISearcher.fs

Lines changed: 17 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,7 @@ type internal AISearcher(oracle: Oracle, aiAgentTrainingMode: Option<AIAgentTrai
9494
| Some(SendEachStep _) -> TrainingSendEachStep
9595
| Some(SendModel _) -> TrainingSendModel
9696
| None -> Runner
97+
9798
let pick selector =
9899
if useDefaultSearcher then
99100
defaultSearcherSteps <- defaultSearcherSteps + 1u<step>
@@ -122,7 +123,7 @@ type internal AISearcher(oracle: Oracle, aiAgentTrainingMode: Option<AIAgentTrai
122123

123124
Application.applicationGraphDelta.Clear()
124125

125-
if stepsToPlay = stepsPlayed then
126+
if aiMode <> Runner && stepsToPlay = stepsPlayed then //TODO FIX IT CAREFULLY!!!!
126127
None
127128
else
128129
let toPredict =
@@ -136,6 +137,7 @@ type internal AISearcher(oracle: Oracle, aiAgentTrainingMode: Option<AIAgentTrai
136137
| Runner -> gameState.Value
137138

138139
let stateId = oracle.Predict toPredict
140+
139141
afterFirstAIPeek <- true
140142
let state = availableStates |> Seq.tryFind (fun s -> s.internalId = stateId)
141143
lastCollectedStatistics <- statistics
@@ -147,6 +149,7 @@ type internal AISearcher(oracle: Oracle, aiAgentTrainingMode: Option<AIAgentTrai
147149
incorrectPredictedStateId <- true
148150
oracle.Feedback(Feedback.IncorrectPredictedStateId stateId)
149151
None
152+
150153
static member updateGameState (delta: GameState) (gameState: Option<GameState>) =
151154
match gameState with
152155
| None -> Some delta
@@ -199,11 +202,11 @@ type internal AISearcher(oracle: Oracle, aiAgentTrainingMode: Option<AIAgentTrai
199202
Some
200203
<| GameState(vertices.ToArray(), states, pathConditionVertices.ToArray(), edges.ToArray())
201204

202-
static member convertOutputToJson (output: IDisposableReadOnlyCollection<OrtValue>) =
205+
static member convertOutputToJson(output: IDisposableReadOnlyCollection<OrtValue>) =
203206
seq { 0 .. output.Count - 1 }
204207
|> Seq.map (fun i -> output[i].GetTensorDataAsSpan<float32>().ToArray())
205208

206-
209+
207210

208211
new
209212
(
@@ -309,11 +312,11 @@ type internal AISearcher(oracle: Oracle, aiAgentTrainingMode: Option<AIAgentTrai
309312
stateIds.Add(v.Id, i)
310313
let j = i * numOfStateAttributes
311314
attributes.[j] <- float32 v.Position
312-
attributes.[j + 2] <- float32 v.VisitedAgainVertices
313-
attributes.[j + 3] <- float32 v.VisitedNotCoveredVerticesInZone
314-
attributes.[j + 4] <- float32 v.VisitedNotCoveredVerticesOutOfZone
315-
attributes.[j + 5] <- float32 v.StepWhenMovedLastTime
316-
attributes.[j + 6] <- float32 v.InstructionsVisitedInCurrentBlock
315+
attributes.[j + 1] <- float32 v.VisitedAgainVertices
316+
attributes.[j + 2] <- float32 v.VisitedNotCoveredVerticesInZone
317+
attributes.[j + 3] <- float32 v.VisitedNotCoveredVerticesOutOfZone
318+
attributes.[j + 4] <- float32 v.StepWhenMovedLastTime
319+
attributes.[j + 5] <- float32 v.InstructionsVisitedInCurrentBlock
317320

318321
OrtValue.CreateTensorValueFromMemory(attributes, shape), numOfParentOfEdges, numOfHistoryEdges
319322

@@ -325,17 +328,17 @@ type internal AISearcher(oracle: Oracle, aiAgentTrainingMode: Option<AIAgentTrai
325328
for v in gameState.PathConditionVertices do
326329
for child in v.Children do
327330
// from vertex to child
328-
index.[firstFreePositionOfIndex] <- pathConditionVerticesIds.[v.Id]
331+
index.[firstFreePositionOfIndex] <- int64 pathConditionVerticesIds.[v.Id]
329332

330-
index.[firstFreePositionOfIndex + 2 * numOfPcToPcEdges] <-
331-
pathConditionVerticesIds.[child]
333+
index.[firstFreePositionOfIndex + numOfPcToPcEdges] <-
334+
int64 pathConditionVerticesIds.[child]
332335

333336
firstFreePositionOfIndex <- firstFreePositionOfIndex + 1
334337
// from child to vertex
335-
index.[firstFreePositionOfIndex] <- pathConditionVerticesIds.[child]
338+
index.[firstFreePositionOfIndex] <- int64 pathConditionVerticesIds.[child]
336339

337-
index.[firstFreePositionOfIndex + 2 * numOfPcToPcEdges] <-
338-
pathConditionVerticesIds.[v.Id]
340+
index.[firstFreePositionOfIndex + numOfPcToPcEdges] <-
341+
int64 pathConditionVerticesIds.[v.Id]
339342

340343
firstFreePositionOfIndex <- firstFreePositionOfIndex + 1
341344

0 commit comments

Comments
 (0)