Skip to content

Commit 43ddc23

Browse files
committed
fix inference scripts
1 parent 6857e13 commit 43ddc23

File tree

2 files changed

+24
-11
lines changed

2 files changed

+24
-11
lines changed

AI-and-Analytics/End-to-end-Workloads/LanguageIdentification/Inference/inference_commonVoice.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -173,12 +173,12 @@ def main(argv):
173173
data = datafile(testDataDirectory, filename)
174174
predict_list = []
175175
use_entire_audio_file = False
176-
if data.waveduration < sample_dur:
176+
if int(data.waveduration) <= sample_dur:
177177
# Use entire audio file if the duration is less than the sampling duration
178178
use_entire_audio_file = True
179179
sample_list = [0 for _ in range(sample_size)]
180180
else:
181-
start_time_list = list(range(sample_size - int(data.waveduration) + 1))
181+
start_time_list = list(range(0, int(data.waveduration) - sample_dur))
182182
sample_list = []
183183
for i in range(sample_size):
184184
sample_list.append(random.sample(start_time_list, 1)[0])
@@ -198,10 +198,6 @@ def main(argv):
198198
predict_list.append(' ')
199199
pass
200200

201-
# Clean up
202-
if use_entire_audio_file:
203-
os.remove("./" + data.filename)
204-
205201
# Pick the top rated prediction result
206202
occurence_count = Counter(predict_list)
207203
total_count = sum(occurence_count.values())

AI-and-Analytics/End-to-end-Workloads/LanguageIdentification/Inference/inference_custom.py

Lines changed: 22 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -195,13 +195,13 @@ def main(argv):
195195
with open(OUTPUT_SUMMARY_CSV_FILE, 'w') as f:
196196
writer = csv.writer(f)
197197
writer.writerow(["Audio File",
198-
"Input Frequency",
198+
"Input Frequency (Hz)",
199199
"Expected Language",
200200
"Top Consensus",
201201
"Top Consensus %",
202202
"Second Consensus",
203203
"Second Consensus %",
204-
"Average Latency",
204+
"Average Latency (s)",
205205
"Result"])
206206

207207
total_samples = 0
@@ -273,12 +273,12 @@ def main(argv):
273273
predict_list = []
274274
use_entire_audio_file = False
275275
latency_sum = 0.0
276-
if data.waveduration < sample_dur:
276+
if int(data.waveduration) <= sample_dur:
277277
# Use entire audio file if the duration is less than the sampling duration
278278
use_entire_audio_file = True
279279
sample_list = [0 for _ in range(sample_size)]
280280
else:
281-
start_time_list = list(range(sample_size - int(data.waveduration) + 1))
281+
start_time_list = list(range(int(data.waveduration) - sample_dur))
282282
sample_list = []
283283
for i in range(sample_size):
284284
sample_list.append(random.sample(start_time_list, 1)[0])
@@ -346,11 +346,28 @@ def main(argv):
346346
avg_latency,
347347
result
348348
])
349+
else:
350+
# Write results to a .csv file
351+
with open(OUTPUT_SUMMARY_CSV_FILE, 'a') as f:
352+
writer = csv.writer(f)
353+
writer.writerow([
354+
filename,
355+
sample_rate_for_csv,
356+
"N/A",
357+
top_occurance,
358+
str(topPercentage) + "%",
359+
sec_occurance,
360+
str(secPercentage) + "%",
361+
avg_latency,
362+
"N/A"
363+
])
364+
349365

350366
if ground_truth_compare:
351367
# Summary of results
352368
print("\n\n Correctly predicted %d/%d\n" %(correct_predictions, total_samples))
353-
print("\n See %s for summary\n" %(OUTPUT_SUMMARY_CSV_FILE))
369+
370+
print("\n See %s for summary\n" %(OUTPUT_SUMMARY_CSV_FILE))
354371

355372
elif os.path.isfile(path):
356373
print("\nIt is a normal file", path)

0 commit comments

Comments
 (0)