@@ -5,9 +5,28 @@ import {
5
5
} from '../azure-openai/utils' ;
6
6
import { ProviderAPIConfig } from '../types' ;
7
7
8
+ const NON_INFERENCE_ENDPOINTS = [
9
+ 'createBatch' ,
10
+ 'retrieveBatch' ,
11
+ 'cancelBatch' ,
12
+ 'getBatchOutput' ,
13
+ 'listBatches' ,
14
+ 'uploadFile' ,
15
+ 'listFiles' ,
16
+ 'retrieveFile' ,
17
+ 'deleteFile' ,
18
+ 'retrieveFileContent' ,
19
+ ] ;
20
+
8
21
const AzureAIInferenceAPI : ProviderAPIConfig = {
9
- getBaseURL : ( { providerOptions } ) => {
22
+ getBaseURL : ( { providerOptions, fn } ) => {
10
23
const { provider, azureFoundryUrl } = providerOptions ;
24
+
25
+ // Azure Foundry URL includes `/deployments/<deployment>`, strip out and append openai for batches/finetunes
26
+ if ( fn && NON_INFERENCE_ENDPOINTS . includes ( fn ) ) {
27
+ return new URL ( azureFoundryUrl ?? '' ) . origin + '/openai' ;
28
+ }
29
+
11
30
if ( provider === GITHUB ) {
12
31
return 'https://models.inference.ai.azure.com' ;
13
32
}
@@ -17,7 +36,7 @@ const AzureAIInferenceAPI: ProviderAPIConfig = {
17
36
18
37
return '' ;
19
38
} ,
20
- headers : async ( { providerOptions } ) => {
39
+ headers : async ( { providerOptions, fn } ) => {
21
40
const {
22
41
apiKey,
23
42
azureExtraParams,
@@ -31,6 +50,13 @@ const AzureAIInferenceAPI: ProviderAPIConfig = {
31
50
...( azureDeploymentName && {
32
51
'azureml-model-deployment' : azureDeploymentName ,
33
52
} ) ,
53
+ ...( [ 'createTranscription' , 'createTranslation' , 'uploadFile' ] . includes (
54
+ fn
55
+ )
56
+ ? {
57
+ 'Content-Type' : 'multipart/form-data' ,
58
+ }
59
+ : { } ) ,
34
60
} ;
35
61
if ( azureAdToken ) {
36
62
headers [ 'Authorization' ] =
@@ -70,14 +96,37 @@ const AzureAIInferenceAPI: ProviderAPIConfig = {
70
96
}
71
97
return headers ;
72
98
} ,
73
- getEndpoint : ( { providerOptions, fn } ) => {
99
+ getEndpoint : ( { providerOptions, fn, gatewayRequestURL } ) => {
74
100
const { azureApiVersion, urlToFetch } = providerOptions ;
75
101
let mappedFn = fn ;
76
102
103
+ const urlObj = new URL ( gatewayRequestURL ) ;
104
+ const path = urlObj . pathname . replace ( '/v1' , '' ) ;
105
+ const searchParams = urlObj . searchParams ;
106
+
107
+ if ( azureApiVersion ) {
108
+ searchParams . set ( 'api-version' , azureApiVersion ) ;
109
+ }
110
+
77
111
const ENDPOINT_MAPPING : Record < string , string > = {
78
112
complete : '/completions' ,
79
113
chatComplete : '/chat/completions' ,
80
114
embed : '/embeddings' ,
115
+ realtime : '/realtime' ,
116
+ imageGenerate : '/images/generations' ,
117
+ createSpeech : '/audio/speech' ,
118
+ createTranscription : '/audio/transcriptions' ,
119
+ createTranslation : '/audio/translations' ,
120
+ uploadFile : path ,
121
+ retrieveFile : path ,
122
+ listFiles : path ,
123
+ deleteFile : path ,
124
+ retrieveFileContent : path ,
125
+ listBatches : path ,
126
+ retrieveBatch : path ,
127
+ cancelBatch : path ,
128
+ getBatchOutput : path ,
129
+ createBatch : path ,
81
130
} ;
82
131
83
132
const isGithub = providerOptions . provider === GITHUB ;
@@ -92,23 +141,40 @@ const AzureAIInferenceAPI: ProviderAPIConfig = {
92
141
}
93
142
}
94
143
95
- const apiVersion = azureApiVersion ? `?api-version= ${ azureApiVersion } ` : '' ;
144
+ const searchParamsString = searchParams . toString ( ) ;
96
145
switch ( mappedFn ) {
97
146
case 'complete' : {
98
147
return isGithub
99
148
? ENDPOINT_MAPPING [ mappedFn ]
100
- : `${ ENDPOINT_MAPPING [ mappedFn ] } ${ apiVersion } ` ;
149
+ : `${ ENDPOINT_MAPPING [ mappedFn ] } ? ${ searchParamsString } ` ;
101
150
}
102
151
case 'chatComplete' : {
103
152
return isGithub
104
153
? ENDPOINT_MAPPING [ mappedFn ]
105
- : `${ ENDPOINT_MAPPING [ mappedFn ] } ${ apiVersion } ` ;
154
+ : `${ ENDPOINT_MAPPING [ mappedFn ] } ? ${ searchParamsString } ` ;
106
155
}
107
156
case 'embed' : {
108
157
return isGithub
109
158
? ENDPOINT_MAPPING [ mappedFn ]
110
- : `${ ENDPOINT_MAPPING [ mappedFn ] } ${ apiVersion } ` ;
159
+ : `${ ENDPOINT_MAPPING [ mappedFn ] } ? ${ searchParamsString } ` ;
111
160
}
161
+ case 'realtime' :
162
+ case 'imageGenerate' :
163
+ case 'createSpeech' :
164
+ case 'createTranscription' :
165
+ case 'createTranslation' :
166
+ case 'cancelBatch' :
167
+ case 'createBatch' :
168
+ case 'getBatchOutput' :
169
+ case 'retrieveBatch' :
170
+ case 'listBatches' :
171
+ case 'retrieveFile' :
172
+ case 'listFiles' :
173
+ case 'deleteFile' :
174
+ case 'retrieveFileContent' : {
175
+ return `${ ENDPOINT_MAPPING [ mappedFn ] } ?${ searchParamsString } ` ;
176
+ }
177
+
112
178
default :
113
179
return '' ;
114
180
}
0 commit comments