
Commit 0e97f9c

piotroootzolov authored and committed
Introduce AzureOpenAI transcription support
- Breaking change: classes from the org.springframework.ai.openai.metadata.audio.transcription package have been moved to the org.springframework.ai.audio.transcription package (import sketch below).
- The AzureOpenAiAudioTranscriptionModel has been added to the auto-configuration.
- The spring.ai.azure.openai.audio.transcription prefix was introduced for configuration properties.
- Introduces options properties covering all transcription options (see AzureOpenAiAudioTranscriptionOptions).
- Fix missing MutableResponseMetadata.
- Add docs.
- Adjust code to the updated ResponseMetadata design.
- Add test to AzureOpenAiAutoConfiguration.
- Add missing AzureOpenAiAudioTranscriptionModel tests.
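For callers affected by the breaking change, the migration should be import-only. A minimal sketch; the commit message does not list exactly which classes moved, so the names below are inferred from the new imports in this diff:

// Before (old package, per the commit message):
// import org.springframework.ai.openai.metadata.audio.transcription.AudioTranscriptionResponse;

// After (new package, matching the imports used in this diff):
import org.springframework.ai.audio.transcription.AudioTranscription;
import org.springframework.ai.audio.transcription.AudioTranscriptionPrompt;
import org.springframework.ai.audio.transcription.AudioTranscriptionResponse;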
1 parent 92ec519 commit 0e97f9c

File tree

31 files changed: +1087 −117 lines changed

AzureOpenAiAudioTranscriptionModel.java — Lines changed: 201 additions & 0 deletions
@@ -0,0 +1,201 @@
/*
 * Copyright 2023 - 2024 the original author or authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.springframework.ai.azure.openai;

import com.azure.ai.openai.OpenAIClient;
import com.azure.ai.openai.models.AudioTranscriptionFormat;
import com.azure.ai.openai.models.AudioTranscriptionOptions;
import com.azure.ai.openai.models.AudioTranscriptionTimestampGranularity;
import com.azure.core.http.rest.Response;
import org.springframework.ai.audio.transcription.AudioTranscription;
import org.springframework.ai.audio.transcription.AudioTranscriptionPrompt;
import org.springframework.ai.audio.transcription.AudioTranscriptionResponse;
import org.springframework.ai.azure.openai.AzureOpenAiAudioTranscriptionOptions.GranularityType;
import org.springframework.ai.azure.openai.AzureOpenAiAudioTranscriptionOptions.StructuredResponse;
import org.springframework.ai.azure.openai.AzureOpenAiAudioTranscriptionOptions.StructuredResponse.Segment;
import org.springframework.ai.azure.openai.AzureOpenAiAudioTranscriptionOptions.StructuredResponse.Word;
import org.springframework.ai.azure.openai.AzureOpenAiAudioTranscriptionOptions.TranscriptResponseFormat;
import org.springframework.ai.azure.openai.metadata.AzureOpenAiAudioTranscriptionResponseMetadata;
import org.springframework.ai.model.Model;
import org.springframework.ai.model.ModelOptionsUtils;
import org.springframework.core.io.Resource;
import org.springframework.util.Assert;
import org.springframework.util.StringUtils;

import java.io.IOException;
import java.util.List;

/**
 * AzureOpenAI audio transcription client implementation backed by
 * {@link OpenAIClient}. You provide the audio file to transcribe and the
 * desired output format of the transcription as input.
 *
 * @author Piotr Olaszewski
 */
public class AzureOpenAiAudioTranscriptionModel implements Model<AudioTranscriptionPrompt, AudioTranscriptionResponse> {

    private static final List<AudioTranscriptionFormat> JSON_FORMATS = List.of(AudioTranscriptionFormat.JSON,
            AudioTranscriptionFormat.VERBOSE_JSON);

    // Placeholder file name passed to the Azure client; the audio content itself
    // comes from the bytes of the prompt's Resource.
    private static final String FILENAME_MARKER = "filename.wav";

    private final OpenAIClient openAIClient;

    private final AzureOpenAiAudioTranscriptionOptions defaultOptions;

    public AzureOpenAiAudioTranscriptionModel(OpenAIClient openAIClient, AzureOpenAiAudioTranscriptionOptions options) {
        this.openAIClient = openAIClient;
        this.defaultOptions = options;
    }

    public String call(Resource audioResource) {
        AudioTranscriptionPrompt transcriptionRequest = new AudioTranscriptionPrompt(audioResource);
        return call(transcriptionRequest).getResult().getOutput();
    }

    @Override
    public AudioTranscriptionResponse call(AudioTranscriptionPrompt audioTranscriptionPrompt) {
        String deploymentOrModelName = getDeploymentName(audioTranscriptionPrompt);
        AudioTranscriptionOptions audioTranscriptionOptions = toAudioTranscriptionOptions(audioTranscriptionPrompt);

        AudioTranscriptionFormat responseFormat = audioTranscriptionOptions.getResponseFormat();
        if (JSON_FORMATS.contains(responseFormat)) {
            // JSON formats return a structured payload; map it into Spring AI's StructuredResponse.
            var audioTranscription = openAIClient.getAudioTranscription(deploymentOrModelName, FILENAME_MARKER,
                    audioTranscriptionOptions);

            List<Word> words = null;
            if (audioTranscription.getWords() != null) {
                words = audioTranscription.getWords().stream().map(w -> {
                    float start = (float) w.getStart().toSeconds();
                    float end = (float) w.getEnd().toSeconds();
                    return new Word(w.getWord(), start, end);
                }).toList();
            }

            List<Segment> segments = null;
            if (audioTranscription.getSegments() != null) {
                segments = audioTranscription.getSegments().stream().map(s -> {
                    float start = (float) s.getStart().toSeconds();
                    float end = (float) s.getEnd().toSeconds();
                    return new Segment(s.getId(), s.getSeek(), start, end, s.getText(), s.getTokens(),
                            (float) s.getTemperature(), (float) s.getAvgLogprob(), (float) s.getCompressionRatio(),
                            (float) s.getNoSpeechProb());
                }).toList();
            }

            Float duration = audioTranscription.getDuration() == null ? null
                    : (float) audioTranscription.getDuration().toSeconds();
            StructuredResponse structuredResponse = new StructuredResponse(audioTranscription.getLanguage(), duration,
                    audioTranscription.getText(), words, segments);

            AudioTranscription transcript = new AudioTranscription(structuredResponse.text());
            AzureOpenAiAudioTranscriptionResponseMetadata metadata = AzureOpenAiAudioTranscriptionResponseMetadata
                .from(structuredResponse);

            return new AudioTranscriptionResponse(transcript, metadata);
        }
        else {
            // Non-JSON formats return the transcription as plain text.
            Response<String> audioTranscription = openAIClient.getAudioTranscriptionTextWithResponse(
                    deploymentOrModelName, FILENAME_MARKER, audioTranscriptionOptions, null);
            String text = audioTranscription.getValue();
            AudioTranscription transcript = new AudioTranscription(text);
            return new AudioTranscriptionResponse(transcript, AzureOpenAiAudioTranscriptionResponseMetadata.from(text));
        }
    }

    // Prefer an explicit Azure deployment name; fall back to the model name.
    private String getDeploymentName(AudioTranscriptionPrompt audioTranscriptionPrompt) {
        var runtimeOptions = audioTranscriptionPrompt.getOptions();

        if (defaultOptions != null) {
            runtimeOptions = ModelOptionsUtils.merge(runtimeOptions, this.defaultOptions,
                    AzureOpenAiAudioTranscriptionOptions.class);
        }

        if (runtimeOptions instanceof AzureOpenAiAudioTranscriptionOptions azureOpenAiAudioTranscriptionOptions) {
            String deploymentName = azureOpenAiAudioTranscriptionOptions.getDeploymentName();
            if (StringUtils.hasText(deploymentName)) {
                return deploymentName;
            }
        }

        return runtimeOptions.getModel();
    }

    private AudioTranscriptionOptions toAudioTranscriptionOptions(AudioTranscriptionPrompt audioTranscriptionPrompt) {
        var runtimeOptions = audioTranscriptionPrompt.getOptions();

        if (this.defaultOptions != null) {
            runtimeOptions = ModelOptionsUtils.merge(runtimeOptions, this.defaultOptions,
                    AzureOpenAiAudioTranscriptionOptions.class);
        }

        byte[] bytes = toBytes(audioTranscriptionPrompt.getInstructions());
        AudioTranscriptionOptions audioTranscriptionOptions = new AudioTranscriptionOptions(bytes);

        if (runtimeOptions instanceof AzureOpenAiAudioTranscriptionOptions azureOpenAiAudioTranscriptionOptions) {
            String model = azureOpenAiAudioTranscriptionOptions.getModel();
            if (StringUtils.hasText(model)) {
                audioTranscriptionOptions.setModel(model);
            }

            String language = azureOpenAiAudioTranscriptionOptions.getLanguage();
            if (StringUtils.hasText(language)) {
                audioTranscriptionOptions.setLanguage(language);
            }

            String prompt = azureOpenAiAudioTranscriptionOptions.getPrompt();
            if (StringUtils.hasText(prompt)) {
                audioTranscriptionOptions.setPrompt(prompt);
            }

            Float temperature = azureOpenAiAudioTranscriptionOptions.getTemperature();
            if (temperature != null) {
                audioTranscriptionOptions.setTemperature(temperature.doubleValue());
            }

            TranscriptResponseFormat responseFormat = azureOpenAiAudioTranscriptionOptions.getResponseFormat();
            List<GranularityType> granularityType = azureOpenAiAudioTranscriptionOptions.getGranularityType();

            if (responseFormat != null) {
                audioTranscriptionOptions.setResponseFormat(responseFormat.getValue());
                // verbose_json defaults to segment-level timestamps when no granularity is given.
                if (responseFormat == TranscriptResponseFormat.VERBOSE_JSON && granularityType == null) {
                    granularityType = List.of(GranularityType.SEGMENT);
                }
            }

            if (granularityType != null) {
                Assert.isTrue(responseFormat == TranscriptResponseFormat.VERBOSE_JSON,
                        "response_format must be set to verbose_json to use timestamp granularities.");
                List<AudioTranscriptionTimestampGranularity> granularity = granularityType.stream()
                    .map(GranularityType::getValue)
                    .toList();
                audioTranscriptionOptions.setTimestampGranularities(granularity);
            }
        }

        return audioTranscriptionOptions;
    }

    private static byte[] toBytes(Resource resource) {
        try {
            return resource.getInputStream().readAllBytes();
        }
        catch (IOException e) {
            throw new IllegalArgumentException("Failed to read resource: " + resource, e);
        }
    }

}
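A minimal usage sketch of the new model. The endpoint, key, and deployment name are placeholders, and the builder on AzureOpenAiAudioTranscriptionOptions is an assumption based on the pattern other Spring AI options classes follow:

import com.azure.ai.openai.OpenAIClient;
import com.azure.ai.openai.OpenAIClientBuilder;
import com.azure.core.credential.AzureKeyCredential;
import org.springframework.ai.azure.openai.AzureOpenAiAudioTranscriptionModel;
import org.springframework.ai.azure.openai.AzureOpenAiAudioTranscriptionOptions;
import org.springframework.core.io.FileSystemResource;

public class TranscriptionExample {

    public static void main(String[] args) {
        // Placeholder endpoint and key for illustration only.
        OpenAIClient client = new OpenAIClientBuilder()
            .endpoint("https://my-resource.openai.azure.com")
            .credential(new AzureKeyCredential(System.getenv("AZURE_OPENAI_API_KEY")))
            .buildClient();

        // Builder usage is an assumption; plain setters would work the same way.
        var options = AzureOpenAiAudioTranscriptionOptions.builder()
            .withDeploymentName("whisper") // placeholder deployment name
            .build();

        var transcriptionModel = new AzureOpenAiAudioTranscriptionModel(client, options);

        // The convenience call(Resource) overload returns the transcript text directly.
        String text = transcriptionModel.call(new FileSystemResource("speech.wav"));
        System.out.println(text);
    }
}

Under Spring Boot auto-configuration the model is instead provided as a bean and configured through the spring.ai.azure.openai.audio.transcription property prefix introduced by this commit.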
