Skip to content

Commit 2954660

Browse files
committed
Fix plots_stats breaking because of new saving system
1 parent ce69570 commit 2954660

File tree

1 file changed

+41
-69
lines changed

1 file changed

+41
-69
lines changed

main.py

Lines changed: 41 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -8,81 +8,48 @@
88

99
class plots_stats:
1010
def printWinPercentage():
11-
prisonersLog = os.path.join(working_dir, 'results.csv')
12-
simResults = {}
13-
maxSimId = -1
14-
with open(prisonersLog, mode='r', newline='') as file:
15-
reader = csv.DictReader(file)
16-
for row in reader:
17-
simId = int(row['Simulation'])
18-
maxSimId = max(maxSimId, simId)
19-
found = row['FoundBox'] == 'True'
20-
if simId not in simResults:
21-
simResults[simId] = []
22-
simResults[simId].append(found)
23-
24-
if maxSimId == -1:
25-
print("No simulations found.")
26-
return
27-
28-
wins = sum(all(results) for results in simResults.values())
29-
totalSims = maxSimId + 1
11+
results = saving.loadResults()
12+
total_sims = len(results)
13+
wins = sum(1 for result in results if result["escaped"])
3014

31-
winPercentage = (wins / totalSims) * 100
15+
winPercentage = (wins / total_sims) * 100
3216
winStr = f"{winPercentage:.10f}".rstrip('0').rstrip('.')
3317
print(f"\nWin percentage: {winStr}%")
3418

3519
def printAvgBoxChecks():
36-
prisonersLog = os.path.join(working_dir, 'results.csv')
37-
num_prisoners = -1
38-
num_simulations = -1
39-
simResults = {}
40-
avgChecksPerPrisoner = {}
41-
42-
with open(prisonersLog, mode='r', newline='') as file:
43-
reader = csv.DictReader(file)
44-
for row in reader:
45-
simId = int(row['Simulation'])
46-
prisoner = int(row['PrisonerID'])
47-
checkedBoxesCount = int(row['CheckedBoxesCount'])
48-
num_prisoners = max(num_prisoners, prisoner + 1)
49-
num_simulations = max(num_simulations, simId + 1)
50-
if simId not in simResults:
51-
simResults[simId] = {}
52-
simResults[simId][prisoner] = checkedBoxesCount
53-
54-
for prisoner in range(num_prisoners):
55-
totalChecks = sum(simResults[simId].get(prisoner, 0) for simId in simResults)
56-
avgChecksPerPrisoner[prisoner] = totalChecks / num_simulations
57-
overall_avg = sum(avgChecksPerPrisoner.values()) / len(avgChecksPerPrisoner)
20+
results = saving.loadResults()
21+
total_sims = len(results)
22+
num_prisoners = len(results[0]["prisoners"])
23+
checksPerPrisoner = {i: 0 for i in range(num_prisoners)}
24+
25+
for result in results:
26+
for prisonerId, prisoner_data in enumerate(result["prisoners"]):
27+
checksPerPrisoner[prisonerId] += len(prisoner_data["checked_boxes"])
28+
avgChecksPerPrisoner = {prisoner: checks / total_sims for prisoner, checks in checksPerPrisoner.items()}
5829

30+
overallAvg = sum(avgChecksPerPrisoner.values()) / num_prisoners
31+
5932
plt.bar(avgChecksPerPrisoner.keys(), avgChecksPerPrisoner.values())
60-
plt.axhline(y=overall_avg, color='r', linestyle='-', label=f'Overall Average: {overall_avg:.2f}')
33+
plt.axhline(y=overallAvg, color='r', linestyle='-', label=f'Overall Average: {overallAvg:.2f}')
6134
plt.xlabel("Prisoner ID")
6235
plt.ylabel("Average Checked Boxes")
6336
plt.title("Average Checked Boxes per Prisoner")
6437
plt.legend()
6538
plt.show()
6639

6740
def printPctFinds():
68-
prisonersLog = os.path.join(working_dir, 'results.csv')
69-
num_prisoners = -1
70-
num_simulations = -1
71-
findCounts = {}
72-
with open(prisonersLog, mode='r', newline='') as file:
73-
reader = csv.DictReader(file)
74-
for row in reader:
75-
prisoner = int(row['PrisonerID'])
76-
simId = int(row['Simulation'])
77-
found = row['FoundBox'] == 'True'
78-
num_prisoners = max(num_prisoners, prisoner + 1)
79-
num_simulations = max(num_simulations, simId + 1)
80-
if prisoner not in findCounts:
81-
findCounts[prisoner] = 0
82-
if found:
83-
findCounts[prisoner] += 1
84-
pctFinds = {prisoner: (count / num_simulations) * 100 for prisoner, count in findCounts.items()}
85-
avgPctFinds = sum(pctFinds.values()) / len(pctFinds)
41+
results = saving.loadResults()
42+
total_sims = len(results)
43+
num_prisoners = len(results[0]["prisoners"])
44+
findsPerPrisoner = {i: 0 for i in range(num_prisoners)}
45+
46+
for result in results:
47+
for prisonerId, prisoner_data in enumerate(result["prisoners"]):
48+
if prisoner_data["found"]:
49+
findsPerPrisoner[prisonerId] += 1
50+
pctFinds = {prisoner: (finds / total_sims) * 100 for prisoner, finds in findsPerPrisoner.items()}
51+
52+
avgPctFinds = sum(pctFinds.values()) / num_prisoners
8653

8754
plt.bar(pctFinds.keys(), pctFinds.values())
8855
plt.axhline(y=avgPctFinds, color='r', linestyle='-', label=f'Overall Average: {avgPctFinds:.2f}%')
@@ -94,9 +61,8 @@ def printPctFinds():
9461
plt.show()
9562

9663
def run():
97-
prisonersLog = os.path.join(working_dir, 'results.csv')
98-
if not os.path.exists(prisonersLog):
99-
print("No results file found.")
64+
if not os.path.exists(resultsPath):
65+
print("No results found. Please run simulations first.")
10066
return
10167

10268
while True:
@@ -128,14 +94,19 @@ def save(results, checkpoint):
12894
os.replace(resultsPath + '.tmp', resultsPath)
12995
os.replace(checkpointPath + '.tmp', checkpointPath)
13096

131-
def load():
132-
if os.path.exists(resultsPath) and os.path.exists(checkpointPath):
97+
def loadResults():
98+
if os.path.exists(resultsPath):
13399
with open(resultsPath, 'rb') as file:
134100
results = pickle.load(file)
101+
return results
102+
return None
103+
104+
def loadCheckpoint():
105+
if os.path.exists(checkpointPath):
135106
with open(checkpointPath, 'rb') as file:
136107
checkpoint = pickle.load(file)
137-
return results, checkpoint
138-
return None, None
108+
return checkpoint
109+
return None
139110

140111
def importConfigModule():
141112
configPath = os.path.join(working_dir, "config.py")
@@ -157,7 +128,8 @@ def getWorkingDir():
157128
print(f"Directory {working_dir} does not exist. Please try again.")
158129

159130
def simulatePrisoners():
160-
results, checkpoint = saving.load()
131+
results = saving.loadResults()
132+
checkpoint = saving.loadCheckpoint()
161133
if checkpoint:
162134
startSim = checkpoint.get("last_simulation") + 1
163135
rng = random.Random(cfg.get("seed", None))

0 commit comments

Comments
 (0)