22from typing import Any , List
33
44import openai
5- from openai import api_requestor , util
5+ from openai import api_requestor , error , util
66from openai .api_resources .abstract import APIResource
77
88
99class Image (APIResource ):
1010 OBJECT_NAME = "images"
1111
1212 @classmethod
13- def _get_url (cls , action ):
14- return cls .class_url () + f"/{ action } "
13+ def _get_url (cls , action , azure_action , api_type , api_version ):
14+ if api_type in (util .ApiType .AZURE , util .ApiType .AZURE_AD ) and azure_action is not None :
15+ return f"/{ cls .azure_api_prefix } { cls .class_url ()} /{ action } :{ azure_action } ?api-version={ api_version } "
16+ else :
17+ return f"{ cls .class_url ()} /{ action } "
1518
1619 @classmethod
1720 def create (
@@ -31,12 +34,20 @@ def create(
3134 organization = organization ,
3235 )
3336
34- _ , api_version = cls ._get_api_type_and_version (api_type , api_version )
37+ api_type , api_version = cls ._get_api_type_and_version (api_type , api_version )
3538
3639 response , _ , api_key = requestor .request (
37- "post" , cls ._get_url ("generations" ), params
40+ "post" , cls ._get_url ("generations" , azure_action = "submit" , api_type = api_type , api_version = api_version ), params
3841 )
3942
43+ if api_type in (util .ApiType .AZURE , util .ApiType .AZURE_AD ):
44+ requestor .api_base = "" # operation_location is a full url
45+ response , _ , api_key = requestor ._poll (
46+ "get" , response .operation_location ,
47+ until = lambda response : response .data ['status' ] in [ 'succeeded' ],
48+ failed = lambda response : response .data ['status' ] in [ 'failed' ]
49+ )
50+
4051 return util .convert_to_openai_object (
4152 response , api_key , api_version , organization
4253 )
@@ -60,12 +71,20 @@ async def acreate(
6071 organization = organization ,
6172 )
6273
63- _ , api_version = cls ._get_api_type_and_version (api_type , api_version )
74+ api_type , api_version = cls ._get_api_type_and_version (api_type , api_version )
6475
6576 response , _ , api_key = await requestor .arequest (
66- "post" , cls ._get_url ("generations" ), params
77+ "post" , cls ._get_url ("generations" , azure_action = "submit" , api_type = api_type , api_version = api_version ), params
6778 )
6879
80+ if api_type in (util .ApiType .AZURE , util .ApiType .AZURE_AD ):
81+ requestor .api_base = "" # operation_location is a full url
82+ response , _ , api_key = await requestor ._apoll (
83+ "get" , response .operation_location ,
84+ until = lambda response : response .data ['status' ] in [ 'succeeded' ],
85+ failed = lambda response : response .data ['status' ] in [ 'failed' ]
86+ )
87+
6988 return util .convert_to_openai_object (
7089 response , api_key , api_version , organization
7190 )
@@ -88,9 +107,9 @@ def _prepare_create_variation(
88107 api_version = api_version ,
89108 organization = organization ,
90109 )
91- _ , api_version = cls ._get_api_type_and_version (api_type , api_version )
110+ api_type , api_version = cls ._get_api_type_and_version (api_type , api_version )
92111
93- url = cls ._get_url ("variations" )
112+ url = cls ._get_url ("variations" , azure_action = None , api_type = api_type , api_version = api_version )
94113
95114 files : List [Any ] = []
96115 for key , value in params .items ():
@@ -109,6 +128,9 @@ def create_variation(
109128 organization = None ,
110129 ** params ,
111130 ):
131+ if api_type in (util .ApiType .AZURE , util .ApiType .AZURE_AD ):
132+ raise error .InvalidAPIType ("Variations are not supported by the Azure OpenAI API yet." )
133+
112134 requestor , url , files = cls ._prepare_create_variation (
113135 image ,
114136 api_key ,
@@ -136,6 +158,9 @@ async def acreate_variation(
136158 organization = None ,
137159 ** params ,
138160 ):
161+ if api_type in (util .ApiType .AZURE , util .ApiType .AZURE_AD ):
162+ raise error .InvalidAPIType ("Variations are not supported by the Azure OpenAI API yet." )
163+
139164 requestor , url , files = cls ._prepare_create_variation (
140165 image ,
141166 api_key ,
@@ -171,9 +196,9 @@ def _prepare_create_edit(
171196 api_version = api_version ,
172197 organization = organization ,
173198 )
174- _ , api_version = cls ._get_api_type_and_version (api_type , api_version )
199+ api_type , api_version = cls ._get_api_type_and_version (api_type , api_version )
175200
176- url = cls ._get_url ("edits" )
201+ url = cls ._get_url ("edits" , azure_action = None , api_type = api_type , api_version = api_version )
177202
178203 files : List [Any ] = []
179204 for key , value in params .items ():
@@ -195,6 +220,9 @@ def create_edit(
195220 organization = None ,
196221 ** params ,
197222 ):
223+ if api_type in (util .ApiType .AZURE , util .ApiType .AZURE_AD ):
224+ raise error .InvalidAPIType ("Edits are not supported by the Azure OpenAI API yet." )
225+
198226 requestor , url , files = cls ._prepare_create_edit (
199227 image ,
200228 mask ,
@@ -224,6 +252,9 @@ async def acreate_edit(
224252 organization = None ,
225253 ** params ,
226254 ):
255+ if api_type in (util .ApiType .AZURE , util .ApiType .AZURE_AD ):
256+ raise error .InvalidAPIType ("Edits are not supported by the Azure OpenAI API yet." )
257+
227258 requestor , url , files = cls ._prepare_create_edit (
228259 image ,
229260 mask ,
0 commit comments