@@ -6,6 +6,7 @@ public actor AzureOpenAIService {
     let url: URL
     let endpoint: OpenAIService.Endpoint
     let modelName: String
+    let contextWindow: Int
     let maxToken: Int
     let temperature: Double
     let apiKey: String
@@ -15,7 +16,8 @@ public actor AzureOpenAIService {
         url: String? = nil,
         endpoint: OpenAIService.Endpoint,
         modelName: String,
-        maxToken: Int? = nil,
+        contextWindow: Int,
+        maxToken: Int,
         temperature: Double = 0.2,
         stopWords: [String] = [],
         apiKey: String
@@ -31,7 +33,8 @@ public actor AzureOpenAIService {

         self.endpoint = endpoint
         self.modelName = modelName
-        self.maxToken = maxToken ?? 4096
+        self.maxToken = maxToken
+        self.contextWindow = contextWindow
         self.temperature = temperature
         self.stopWords = stopWords
         self.apiKey = apiKey
@@ -46,31 +49,13 @@ extension AzureOpenAIService: CodeCompletionServiceType {
             CodeCompletionLogger.logger.logPrompt(messages.map {
                 ($0.content, $0.role.rawValue)
             })
-            return AsyncStream<String> { continuation in
-                let task = Task {
-                    let result = try await sendMessages(messages)
-                    try Task.checkCancellation()
-                    continuation.yield(result)
-                    continuation.finish()
-                }
-                continuation.onTermination = { _ in
-                    task.cancel()
-                }
-            }
+            let result = try await sendMessages(messages)
+            return result.compactMap { $0.choices?.first?.delta?.content }.eraseToStream()
         case .completion:
             let prompt = createPrompt(from: request)
             CodeCompletionLogger.logger.logPrompt([(prompt, "user")])
-            return AsyncStream<String> { continuation in
-                let task = Task {
-                    let result = try await sendPrompt(prompt)
-                    try Task.checkCancellation()
-                    continuation.yield(result)
-                    continuation.finish()
-                }
-                continuation.onTermination = { _ in
-                    task.cancel()
-                }
-            }
+            let result = try await sendPrompt(prompt)
+            return result.compactMap { $0.choices?.first?.text }.eraseToStream()
         }
     }
 }
@@ -82,8 +67,8 @@ extension AzureOpenAIService {

     func createMessages(from request: PromptStrategy) -> [Message] {
         let strategy = DefaultTruncateStrategy(maxTokenLimit: max(
-            maxToken / 3 * 2,
-            maxToken - 300 - 20
+            contextWindow / 3 * 2,
+            contextWindow - maxToken - 20
         ))
         let prompts = strategy.createTruncatedPrompt(promptStrategy: request)
         return [
@@ -98,13 +83,16 @@ extension AzureOpenAIService {
         }
     }

-    func sendMessages(_ messages: [Message]) async throws -> String {
+    func sendMessages(
+        _ messages: [Message]
+    ) async throws -> ResponseStream<OpenAIService.ChatCompletionsStreamDataChunk> {
         let requestBody = OpenAIService.ChatCompletionRequestBody(
             model: modelName,
             messages: messages,
             temperature: temperature,
+            stream: true,
             stop: stopWords,
-            max_tokens: 300
+            max_tokens: maxToken
         )

         var request = URLRequest(url: url)
@@ -113,47 +101,56 @@ extension AzureOpenAIService {
         request.httpBody = try encoder.encode(requestBody)
         request.setValue("application/json", forHTTPHeaderField: "Content-Type")
         request.setValue(apiKey, forHTTPHeaderField: "api-key")
-        let (result, response) = try await URLSession.shared.data(for: request)
+        let (result, response) = try await URLSession.shared.bytes(for: request)

         guard let response = response as? HTTPURLResponse else {
             throw CancellationError()
         }

         guard response.statusCode == 200 else {
-            if let error = try? JSONDecoder().decode(APIError.self, from: result) {
-                throw Error.apiError(error)
+            let text = try await result.lines.reduce(into: "") { partialResult, current in
+                partialResult += current
             }
-            throw Error.otherError(String(data: result, encoding: .utf8) ?? "Unknown Error")
+            throw Error.otherError(text)
         }

-        do {
-            let body = try JSONDecoder().decode(
-                OpenAIService.ChatCompletionResponseBody.self,
-                from: result
-            )
-            return body.choices.first?.message.content ?? ""
-        } catch {
-            dump(error)
-            throw Error.decodeError(error)
+        return ResponseStream(result: result) {
+            var text = $0
+            if text.hasPrefix("data: ") {
+                text = String(text.dropFirst(6))
+            }
+            do {
+                let chunk = try JSONDecoder().decode(
+                    OpenAIService.ChatCompletionsStreamDataChunk.self,
+                    from: text.data(using: .utf8) ?? Data()
+                )
+                return .init(chunk: chunk, done: chunk.choices?.first?.finish_reason != nil)
+            } catch {
+                print(error)
+                throw error
+            }
         }
     }

     func createPrompt(from request: PromptStrategy) -> String {
         let strategy = DefaultTruncateStrategy(maxTokenLimit: max(
-            maxToken / 3 * 2,
-            maxToken - 300 - 20
+            contextWindow / 3 * 2,
+            contextWindow - maxToken - 20
         ))
         let prompts = strategy.createTruncatedPrompt(promptStrategy: request)
         return ([request.systemPrompt] + prompts.map(\.content)).joined(separator: "\n\n")
     }

-    func sendPrompt(_ prompt: String) async throws -> String {
+    func sendPrompt(
+        _ prompt: String
+    ) async throws -> ResponseStream<OpenAIService.CompletionsStreamDataChunk> {
         let requestBody = OpenAIService.CompletionRequestBody(
             model: modelName,
             prompt: prompt,
             temperature: temperature,
+            stream: true,
             stop: stopWords,
-            max_tokens: 300
+            max_tokens: maxToken
         )

         var request = URLRequest(url: url)
@@ -162,28 +159,34 @@ extension AzureOpenAIService {
         request.httpBody = try encoder.encode(requestBody)
         request.setValue("application/json", forHTTPHeaderField: "Content-Type")
         request.setValue(apiKey, forHTTPHeaderField: "api-key")
-        let (result, response) = try await URLSession.shared.data(for: request)
+        let (result, response) = try await URLSession.shared.bytes(for: request)

         guard let response = response as? HTTPURLResponse else {
             throw CancellationError()
         }

         guard response.statusCode == 200 else {
-            if let error = try? JSONDecoder().decode(APIError.self, from: result) {
-                throw Error.apiError(error)
+            let text = try await result.lines.reduce(into: "") { partialResult, current in
+                partialResult += current
             }
-            throw Error.otherError(String(data: result, encoding: .utf8) ?? "Unknown Error")
+            throw Error.otherError(text)
         }

-        do {
-            let body = try JSONDecoder().decode(
-                OpenAIService.CompletionResponseBody.self,
-                from: result
-            )
-            return body.choices.first?.text ?? ""
-        } catch {
-            dump(error)
-            throw Error.decodeError(error)
+        return ResponseStream(result: result) {
+            var text = $0
+            if text.hasPrefix("data: ") {
+                text = String(text.dropFirst(6))
+            }
+            do {
+                let chunk = try JSONDecoder().decode(
+                    OpenAIService.CompletionsStreamDataChunk.self,
+                    from: text.data(using: .utf8) ?? Data()
+                )
+                return .init(chunk: chunk, done: chunk.choices?.first?.finish_reason != nil)
+            } catch {
+                print(error)
+                throw error
+            }
        }
     }

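For context, a minimal sketch of how a caller might consume the streaming service after this change. The `getCompletion(_:)` method name, the `request` value, and the concrete argument values are assumptions for illustration (the diff only shows the method body inside the `CodeCompletionServiceType` conformance), and it assumes `.eraseToStream()` yields an `AsyncStream<String>`, as the removed implementation returned.

// Hypothetical call site; names and values below are illustrative assumptions,
// not confirmed by the diff.
let service = AzureOpenAIService(
    endpoint: .chatCompletion,
    modelName: "gpt-35-turbo",     // assumed Azure deployment name
    contextWindow: 16_385,         // assumed model context window
    maxToken: 1_024,               // tokens reserved for the completion itself
    apiKey: "<azure-api-key>"
)

// `request` is a PromptStrategy built elsewhere in the project.
var completion = ""
for await token in try await service.getCompletion(request) {
    completion += token            // tokens arrive incrementally as stream chunks decode
}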