Skip to content

Commit 5c9714c

Browse files
committed
Improve whisper to work on 8-bit and 32-bit WAV files too; also support form data for the language field
1 parent fa7e661 commit 5c9714c

File tree

2 files changed

+47
-13
lines changed

2 files changed

+47
-13
lines changed

koboldcpp.py

Lines changed: 28 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -1901,7 +1901,8 @@ def log_message(self, format, *args):
19011901
super().log_message(format, *args)
19021902
pass
19031903

1904-
def extract_b64string_from_file_upload(self, body):
1904+
def extract_transcribe_from_file_upload(self, body):
1905+
result = {"file": None, "prompt": None, "language": None}
19051906
try:
19061907
if 'content-type' in self.headers and self.headers['content-type']:
19071908
boundary = self.headers['content-type'].split("=")[1].encode()
@@ -1914,15 +1915,27 @@ def extract_b64string_from_file_upload(self, body):
19141915
file_content_start = fpart.find(b'\r\n\r\n') + 4 # Position after headers
19151916
file_content_end = fpart.rfind(b'\r\n') # Ending boundary
19161917
if file_content_start != -1 and file_content_end != -1:
1917-
file_data = fpart[file_content_start:file_content_end]
1918-
file_data_base64 = base64.b64encode(file_data).decode('utf-8',"ignore")
1919-
base64_string = f"data:audio/wav;base64,{file_data_base64}"
1920-
return base64_string
1921-
print("Uploaded file not found.")
1922-
return None
1918+
if "file" in result and result["file"] is None:
1919+
file_data = fpart[file_content_start:file_content_end]
1920+
file_data_base64 = base64.b64encode(file_data).decode('utf-8',"ignore")
1921+
base64_string = f"data:audio/wav;base64,{file_data_base64}"
1922+
result["file"] = base64_string
1923+
1924+
# Check for fields
1925+
detected_prompt_field = re.findall(r'Content-Disposition.*name="prompt"\r\n\r\n(.*)\r\n', fpart.decode('utf-8', errors='ignore'))
1926+
if detected_prompt_field and len(detected_prompt_field)>0:
1927+
result["prompt"] = detected_prompt_field[0].strip() # Extract and strip whitespace
1928+
1929+
detected_lang_field = re.findall(r'Content-Disposition.*name="language"\r\n\r\n(.*)\r\n', fpart.decode('utf-8', errors='ignore'))
1930+
if detected_lang_field and len(detected_lang_field)>0:
1931+
result["language"] = detected_lang_field[0].strip() # Extract and strip whitespace
1932+
1933+
if not ("file" in result and result["file"]):
1934+
print("Uploaded file not found.")
1935+
return result
19231936
except Exception as e:
19241937
print(f"File Upload Process Error: {e}")
1925-
return None
1938+
return result
19261939

19271940
async def generate_text(self, genparams, api_format, stream_flag):
19281941
global friendlymodelname, chatcompl_adapter, currfinishreason
@@ -2742,9 +2755,14 @@ def do_POST(self):
27422755
except Exception:
27432756
genparams = None
27442757
if is_transcribe: #fallback handling of file uploads
2745-
b64wav = self.extract_b64string_from_file_upload(body)
2746-
if b64wav:
2758+
formdata = self.extract_transcribe_from_file_upload(body)
2759+
if "file" in formdata and formdata["file"]:
2760+
b64wav = formdata["file"]
27472761
genparams = {"audio_data":b64wav}
2762+
if "prompt" in formdata and formdata["prompt"]:
2763+
genparams["prompt"] = formdata["prompt"]
2764+
if "language" in formdata and formdata["language"]:
2765+
genparams["language"] = formdata["language"]
27482766

27492767
if not genparams:
27502768
utfprint("Body Err: " + str(body))

otherarch/whispercpp/whisper_adapter.cpp

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -57,8 +57,8 @@ static bool read_wav(const std::string & b64data, std::vector<float>& pcmf32, st
5757
return false;
5858
}
5959

60-
if (wav.bitsPerSample != 16) {
61-
printf("WAV file must be 16-bit\n");
60+
if (wav.bitsPerSample != 8 && wav.bitsPerSample != 16 && wav.bitsPerSample != 32) {
61+
printf("WAV file must be 8-bit, 16-bit or 32-bit. Detected: %d\n",wav.bitsPerSample);
6262
drwav_uninit(&wav);
6363
return false;
6464
}
@@ -67,7 +67,23 @@ static bool read_wav(const std::string & b64data, std::vector<float>& pcmf32, st
6767

6868
std::vector<int16_t> pcm16;
6969
pcm16.resize(n*wav.channels);
70-
drwav_read_pcm_frames_s16(&wav, n, pcm16.data());
70+
71+
if (wav.bitsPerSample == 8) {
72+
// Handle 8-bit PCM and convert to 16-bit
73+
std::vector<uint8_t> pcm8(n * wav.channels);
74+
drwav_read_pcm_frames(&wav, n, pcm8.data());
75+
drwav_u8_to_s16(pcm16.data(), pcm8.data(), n * wav.channels);
76+
} else if (wav.bitsPerSample == 16) {
77+
// Handle 16-bit PCM directly
78+
drwav_read_pcm_frames_s16(&wav, n, pcm16.data());
79+
} else if (wav.bitsPerSample == 32) {
80+
// Handle 32-bit PCM and convert to 16-bit
81+
std::vector<int32_t> pcm32(n * wav.channels);
82+
drwav_read_pcm_frames_s32(&wav, n, pcm32.data());
83+
for (uint64_t i = 0; i < n * wav.channels; ++i) {
84+
pcm16[i] = static_cast<int16_t>(pcm32[i] >> 16); // Scale down by shifting
85+
}
86+
}
7187
drwav_uninit(&wav);
7288

7389
std::vector<float> raw_pcm;

0 commit comments

Comments
 (0)