Skip to content

Commit 006e6a9

Browse files
authored
Merge pull request #16 from richardpark-msft/azure-update-readme
Adding in Azure support
2 parents 5ffed6a + 1fc3338 commit 006e6a9

File tree

6 files changed

+501
-5
lines changed

6 files changed

+501
-5
lines changed

README.md

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -395,6 +395,43 @@ You may also replace the default `http.Client` with
395395
accepted (this overwrites any previous client) and receives requests after any
396396
middleware has been applied.
397397

398+
## Microsoft Azure OpenAI
399+
400+
To use this library with [Azure OpenAI](https://learn.microsoft.com/azure/ai-services/openai/overview), use the option.RequestOption functions in the `azure` package.
401+
402+
```go
403+
package main
404+
405+
import (
406+
"github.com/Azure/azure-sdk-for-go/sdk/azidentity"
407+
"github.com/openai/openai-go"
408+
"github.com/openai/openai-go/azure"
409+
)
410+
411+
func main() {
412+
const azureOpenAIEndpoint = "https://<azure-openai-resource>.openai.azure.com"
413+
414+
// The latest API versions, including previews, can be found here:
415+
// https://learn.microsoft.com/en-us/azure/ai-services/openai/reference#rest-api-versioning
416+
const azureOpenAIAPIVersion = "2024-06-01"
417+
418+
tokenCredential, err := azidentity.NewDefaultAzureCredential(nil)
419+
420+
if err != nil {
421+
fmt.Printf("Failed to create the DefaultAzureCredential: %s", err)
422+
os.Exit(1)
423+
}
424+
425+
client := openai.NewClient(
426+
azure.WithEndpoint(azureOpenAIEndpoint, azureOpenAIAPIVersion),
427+
428+
// Choose between authenticating using a TokenCredential or an API Key
429+
azure.WithTokenCredential(tokenCredential),
430+
// or azure.WithAPIKey(azureOpenAIAPIKey),
431+
)
432+
}
433+
```
434+
398435
## Semantic versioning
399436

400437
This package generally follows [SemVer](https://semver.org/spec/v2.0.0.html) conventions, though certain backwards-incompatible changes may be released as minor versions:

azure/azure.go

Lines changed: 237 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,237 @@
1+
// Package azure provides configuration options so you can connect and use Azure OpenAI using the [openai.Client].
2+
//
3+
// Typical usage of this package will look like this:
4+
//
5+
// client := openai.NewClient(
6+
// azure.WithEndpoint(azureOpenAIEndpoint, azureOpenAIAPIVersion),
7+
// azure.WithTokenCredential(azureIdentityTokenCredential),
8+
// // or azure.WithAPIKey(azureOpenAIAPIKey),
9+
// )
10+
//
11+
// Or, if you want to construct a specific service:
12+
//
13+
// client := openai.NewChatCompletionService(
14+
// azure.WithEndpoint(azureOpenAIEndpoint, azureOpenAIAPIVersion),
15+
// azure.WithTokenCredential(azureIdentityTokenCredential),
16+
// // or azure.WithAPIKey(azureOpenAIAPIKey),
17+
// )
18+
package azure
19+
20+
import (
21+
"bytes"
22+
"encoding/json"
23+
"errors"
24+
"io"
25+
"mime"
26+
"mime/multipart"
27+
"net/http"
28+
"net/url"
29+
"strings"
30+
31+
"github.com/Azure/azure-sdk-for-go/sdk/azcore"
32+
"github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
33+
"github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime"
34+
"github.com/openai/openai-go/internal/requestconfig"
35+
"github.com/openai/openai-go/option"
36+
)
37+
38+
// WithEndpoint configures this client to connect to an Azure OpenAI endpoint.
39+
//
40+
// - endpoint - the Azure OpenAI endpoint to connect to. Ex: https://<azure-openai-resource>.openai.azure.com
41+
// - apiVersion - the Azure OpenAI API version to target (ex: 2024-06-01). See [Azure OpenAI apiversions] for current API versions. This value cannot be empty.
42+
//
43+
// This function should be paired with a call to authenticate, like [azure.WithAPIKey] or [azure.WithTokenCredential], similar to this:
44+
//
45+
// client := openai.NewClient(
46+
// azure.WithEndpoint(azureOpenAIEndpoint, azureOpenAIAPIVersion),
47+
// azure.WithTokenCredential(azureIdentityTokenCredential),
48+
// // or azure.WithAPIKey(azureOpenAIAPIKey),
49+
// )
50+
//
51+
// [Azure OpenAI apiversions]: https://learn.microsoft.com/en-us/azure/ai-services/openai/reference#rest-api-versioning
52+
func WithEndpoint(endpoint string, apiVersion string) option.RequestOption {
53+
if !strings.HasSuffix(endpoint, "/") {
54+
endpoint += "/"
55+
}
56+
57+
endpoint += "openai/"
58+
59+
withQueryAdd := option.WithQueryAdd("api-version", apiVersion)
60+
withEndpoint := option.WithBaseURL(endpoint)
61+
62+
withModelMiddleware := option.WithMiddleware(func(r *http.Request, mn option.MiddlewareNext) (*http.Response, error) {
63+
replacementPath, err := getReplacementPathWithDeployment(r)
64+
65+
if err != nil {
66+
return nil, err
67+
}
68+
69+
r.URL.Path = replacementPath
70+
return mn(r)
71+
})
72+
73+
return func(rc *requestconfig.RequestConfig) error {
74+
if apiVersion == "" {
75+
return errors.New("apiVersion is an empty string, but needs to be set. See https://learn.microsoft.com/en-us/azure/ai-services/openai/reference#rest-api-versioning for details.")
76+
}
77+
78+
if err := withQueryAdd(rc); err != nil {
79+
return err
80+
}
81+
82+
if err := withEndpoint(rc); err != nil {
83+
return err
84+
}
85+
86+
if err := withModelMiddleware(rc); err != nil {
87+
return err
88+
}
89+
90+
return nil
91+
}
92+
}
93+
94+
// WithTokenCredential configures this client to authenticate using an [Azure Identity] TokenCredential.
95+
// This function should be paired with a call to [WithEndpoint] to point to your Azure OpenAI instance.
96+
//
97+
// [Azure Identity]: https://pkg.go.dev/github.com/Azure/azure-sdk-for-go/sdk/azidentity
98+
func WithTokenCredential(tokenCredential azcore.TokenCredential) option.RequestOption {
99+
bearerTokenPolicy := runtime.NewBearerTokenPolicy(tokenCredential, []string{"https://cognitiveservices.azure.com/.default"}, nil)
100+
101+
// add in a middleware that uses the bearer token generated from the token credential
102+
return option.WithMiddleware(func(req *http.Request, next option.MiddlewareNext) (*http.Response, error) {
103+
pipeline := runtime.NewPipeline("azopenai-extensions", version, runtime.PipelineOptions{}, &policy.ClientOptions{
104+
InsecureAllowCredentialWithHTTP: true, // allow for plain HTTP proxies, etc..
105+
PerRetryPolicies: []policy.Policy{
106+
bearerTokenPolicy,
107+
policyAdapter(next),
108+
},
109+
})
110+
111+
req2, err := runtime.NewRequestFromRequest(req)
112+
113+
if err != nil {
114+
return nil, err
115+
}
116+
117+
return pipeline.Do(req2)
118+
})
119+
}
120+
121+
// WithAPIKey configures this client to authenticate using an API key.
122+
// This function should be paired with a call to [WithEndpoint] to point to your Azure OpenAI instance.
123+
func WithAPIKey(apiKey string) option.RequestOption {
124+
// NOTE: there is an option.WithApiKey(), but that adds the value into
125+
// the Authorization header instead so we're doing this instead.
126+
return option.WithHeader("Api-Key", apiKey)
127+
}
128+
129+
// jsonRoutes have JSON payloads - we'll deserialize looking for a .model field in there
130+
// so we won't have to worry about individual types for completions vs embeddings, etc...
131+
var jsonRoutes = map[string]bool{
132+
"/openai/completions": true,
133+
"/openai/chat/completions": true,
134+
"/openai/embeddings": true,
135+
"/openai/audio/speech": true,
136+
"/openai/images/generations": true,
137+
}
138+
139+
// audioMultipartRoutes have mime/multipart payloads. These are less generic - we're very much
140+
// expecting a transcription or translation payload for these.
141+
var audioMultipartRoutes = map[string]bool{
142+
"/openai/audio/transcriptions": true,
143+
"/openai/audio/translations": true,
144+
}
145+
146+
// getReplacementPathWithDeployment parses the request body to extract out the Model parameter (or equivalent)
147+
// (note, the req.Body is fully read as part of this, and is replaced with a bytes.Reader)
148+
func getReplacementPathWithDeployment(req *http.Request) (string, error) {
149+
if jsonRoutes[req.URL.Path] {
150+
return getJSONRoute(req)
151+
}
152+
153+
if audioMultipartRoutes[req.URL.Path] {
154+
return getAudioMultipartRoute(req)
155+
}
156+
157+
// No need to relocate the path. We've already tacked on /openai when we setup the endpoint.
158+
return req.URL.Path, nil
159+
}
160+
161+
func getJSONRoute(req *http.Request) (string, error) {
162+
// we need to deserialize the body, partly, in order to read out the model field.
163+
jsonBytes, err := io.ReadAll(req.Body)
164+
165+
if err != nil {
166+
return "", err
167+
}
168+
169+
// make sure we restore the body so it can be used in later middlewares.
170+
req.Body = io.NopCloser(bytes.NewReader(jsonBytes))
171+
172+
var v *struct {
173+
Model string `json:"model"`
174+
}
175+
176+
if err := json.Unmarshal(jsonBytes, &v); err != nil {
177+
return "", err
178+
}
179+
180+
escapedDeployment := url.PathEscape(v.Model)
181+
return strings.Replace(req.URL.Path, "/openai/", "/openai/deployments/"+escapedDeployment+"/", 1), nil
182+
}
183+
184+
func getAudioMultipartRoute(req *http.Request) (string, error) {
185+
// body is a multipart/mime body type instead.
186+
mimeBytes, err := io.ReadAll(req.Body)
187+
188+
if err != nil {
189+
return "", err
190+
}
191+
192+
// make sure we restore the body so it can be used in later middlewares.
193+
req.Body = io.NopCloser(bytes.NewReader(mimeBytes))
194+
195+
_, mimeParams, err := mime.ParseMediaType(req.Header.Get("Content-Type"))
196+
197+
if err != nil {
198+
return "", err
199+
}
200+
201+
mimeReader := multipart.NewReader(
202+
io.NopCloser(bytes.NewReader(mimeBytes)),
203+
mimeParams["boundary"])
204+
205+
for {
206+
mp, err := mimeReader.NextPart()
207+
208+
if err != nil {
209+
if errors.Is(err, io.EOF) {
210+
return "", errors.New("unable to find the model part in multipart body")
211+
}
212+
213+
return "", err
214+
}
215+
216+
defer mp.Close()
217+
218+
if mp.FormName() == "model" {
219+
modelBytes, err := io.ReadAll(mp)
220+
221+
if err != nil {
222+
return "", err
223+
}
224+
225+
escapedDeployment := url.PathEscape(string(modelBytes))
226+
return strings.Replace(req.URL.Path, "/openai/", "/openai/deployments/"+escapedDeployment+"/", 1), nil
227+
}
228+
}
229+
}
230+
231+
type policyAdapter option.MiddlewareNext
232+
233+
func (mp policyAdapter) Do(req *policy.Request) (*http.Response, error) {
234+
return (option.MiddlewareNext)(mp)(req.Raw())
235+
}
236+
237+
const version = "v.0.1.0"

0 commit comments

Comments
 (0)