|
| 1 | +import CopilotForXcodeKit |
| 2 | +import Foundation |
| 3 | +import Fundamental |
| 4 | + |
| 5 | +public actor AnthropicService { |
| 6 | + let url: URL |
| 7 | + let modelName: String |
| 8 | + let contextWindow: Int |
| 9 | + let maxToken: Int |
| 10 | + let temperature: Double |
| 11 | + let apiKey: String |
| 12 | + let stopWords: [String] |
| 13 | + |
| 14 | + init( |
| 15 | + url: String? = nil, |
| 16 | + modelName: String, |
| 17 | + contextWindow: Int, |
| 18 | + maxToken: Int, |
| 19 | + temperature: Double = 0.2, |
| 20 | + stopWords: [String] = [], |
| 21 | + apiKey: String |
| 22 | + ) { |
| 23 | + self.url = url.flatMap(URL.init(string:)) ?? |
| 24 | + URL(string: "https://api.anthropic.com/v1/messages")! |
| 25 | + self.modelName = modelName |
| 26 | + self.maxToken = maxToken |
| 27 | + self.temperature = temperature |
| 28 | + self.apiKey = apiKey |
| 29 | + self.stopWords = stopWords |
| 30 | + self.contextWindow = contextWindow |
| 31 | + } |
| 32 | + |
| 33 | + public enum Models: String, CaseIterable { |
| 34 | + case claude3Opus = "claude-3-opus-latest" |
| 35 | + case claude35Sonnet = "claude-3-5-sonnet-latest" |
| 36 | + case claude35Haiku = "claude-3-5-haiku-latest" |
| 37 | + |
| 38 | + public var maxToken: Int { |
| 39 | + switch self { |
| 40 | + case .claude3Opus: return 200_000 |
| 41 | + case .claude35Sonnet: return 200_000 |
| 42 | + case .claude35Haiku: return 200_000 |
| 43 | + } |
| 44 | + } |
| 45 | + } |
| 46 | +} |
| 47 | + |
| 48 | +// MARK: - CodeCompletionServiceType Implementation |
| 49 | + |
| 50 | +extension AnthropicService: CodeCompletionServiceType { |
| 51 | + func getCompletion(_ request: PromptStrategy) async throws -> AsyncStream<String> { |
| 52 | + let (messages, systemPrompt) = createMessages(from: request) |
| 53 | + CodeCompletionLogger.logger.logPrompt(messages.map { |
| 54 | + ($0.content, $0.role.rawValue) |
| 55 | + }) |
| 56 | + let result = try await sendMessages(messages, systemPrompt: systemPrompt) |
| 57 | + return result.compactMap { $0.delta?.text }.eraseToStream() |
| 58 | + } |
| 59 | +} |
| 60 | + |
| 61 | +// MARK: - Message Structure and Request Handling |
| 62 | + |
| 63 | +extension AnthropicService { |
| 64 | + public struct Message: Codable { |
| 65 | + public enum Role: String, Codable { |
| 66 | + case user |
| 67 | + case assistant |
| 68 | + } |
| 69 | + |
| 70 | + var role: Role |
| 71 | + var content: String |
| 72 | + } |
| 73 | + |
| 74 | + struct MessageRequestBody: Codable { |
| 75 | + var model: String |
| 76 | + var messages: [Message] |
| 77 | + var system: String? |
| 78 | + var max_tokens: Int |
| 79 | + var temperature: Double |
| 80 | + var stream: Bool = true |
| 81 | + var stop_sequences: [String]? |
| 82 | + |
| 83 | + enum CodingKeys: String, CodingKey { |
| 84 | + case model |
| 85 | + case messages |
| 86 | + case system |
| 87 | + case max_tokens |
| 88 | + case temperature |
| 89 | + case stream |
| 90 | + case stop_sequences |
| 91 | + } |
| 92 | + } |
| 93 | + |
| 94 | + func createMessages(from request: PromptStrategy) -> (messages: [Message], system: String?) { |
| 95 | + let strategy = DefaultTruncateStrategy(maxTokenLimit: max( |
| 96 | + contextWindow / 3 * 2, |
| 97 | + contextWindow - maxToken - 20 |
| 98 | + )) |
| 99 | + let prompts = strategy.createTruncatedPrompt(promptStrategy: request) |
| 100 | + |
| 101 | + let systemPrompt = request.systemPrompt |
| 102 | + |
| 103 | + let messages = prompts.map { prompt in |
| 104 | + Message( |
| 105 | + role: prompt.role == .user ? .user : .assistant, |
| 106 | + content: prompt.content |
| 107 | + ) |
| 108 | + } |
| 109 | + |
| 110 | + return (messages: messages, system: systemPrompt) |
| 111 | + } |
| 112 | + |
| 113 | + func sendMessages(_ messages: [Message], systemPrompt: String?) async throws -> ResponseStream<StreamResponse> { |
| 114 | + let validStopSequences = stopWords.filter { |
| 115 | + !$0.trimmingCharacters(in: .whitespacesAndNewlines).isEmpty |
| 116 | + } |
| 117 | + |
| 118 | + let requestBody = MessageRequestBody( |
| 119 | + model: modelName, |
| 120 | + messages: messages, |
| 121 | + system: systemPrompt, |
| 122 | + max_tokens: maxToken, |
| 123 | + temperature: temperature, |
| 124 | + stop_sequences: validStopSequences |
| 125 | + ) |
| 126 | + |
| 127 | + var request = URLRequest(url: url) |
| 128 | + request.httpMethod = "POST" |
| 129 | + request.setValue("application/json", forHTTPHeaderField: "content-type") |
| 130 | + request.setValue("\(apiKey)", forHTTPHeaderField: "x-api-key") |
| 131 | + request.setValue("2023-06-01", forHTTPHeaderField: "anthropic-version") |
| 132 | + |
| 133 | + let encoder = JSONEncoder() |
| 134 | + request.httpBody = try encoder.encode(requestBody) |
| 135 | + |
| 136 | + let (result, response) = try await URLSession.shared.bytes(for: request) |
| 137 | + |
| 138 | + guard let httpResponse = response as? HTTPURLResponse else { |
| 139 | + throw CancellationError() |
| 140 | + } |
| 141 | + |
| 142 | + guard httpResponse.statusCode == 200 else { |
| 143 | + let text = try await result.lines.reduce(into: "") { partialResult, current in |
| 144 | + partialResult += current |
| 145 | + } |
| 146 | + throw Error.otherError(text) |
| 147 | + } |
| 148 | + |
| 149 | + return ResponseStream(result: result) { |
| 150 | + var text = $0 |
| 151 | + |
| 152 | + if text.hasPrefix("event: ") { |
| 153 | + return .init(chunk: StreamResponse(), done: false) |
| 154 | + } |
| 155 | + |
| 156 | + if text.hasPrefix("data: ") { |
| 157 | + text = String(text.dropFirst(6)) |
| 158 | + |
| 159 | + guard !text.trimmingCharacters(in: .whitespaces).isEmpty else { |
| 160 | + return .init(chunk: StreamResponse(), done: false) |
| 161 | + } |
| 162 | + |
| 163 | + do { |
| 164 | + let chunk = try JSONDecoder().decode( |
| 165 | + StreamResponse.self, |
| 166 | + from: text.data(using: .utf8) ?? Data() |
| 167 | + ) |
| 168 | + return .init( |
| 169 | + chunk: chunk, |
| 170 | + done: chunk.type == "message_stop" |
| 171 | + ) |
| 172 | + } catch { |
| 173 | + print("Error decoding chunk: \(error)") |
| 174 | + throw error |
| 175 | + } |
| 176 | + } |
| 177 | + |
| 178 | + return .init(chunk: StreamResponse(), done: false) |
| 179 | + } |
| 180 | + } |
| 181 | +} |
| 182 | + |
| 183 | +// MARK: - API Response Structures |
| 184 | + |
| 185 | +extension AnthropicService { |
| 186 | + struct StreamResponse: Decodable { |
| 187 | + var type: String? |
| 188 | + var delta: Delta? |
| 189 | + var index: Int? |
| 190 | + var content: [Content]? |
| 191 | + |
| 192 | + struct Delta: Decodable { |
| 193 | + var text: String? |
| 194 | + var type: String? |
| 195 | + } |
| 196 | + |
| 197 | + struct Content: Decodable { |
| 198 | + var text: String |
| 199 | + var type: String |
| 200 | + } |
| 201 | + } |
| 202 | + |
| 203 | + struct APIError: Decodable { |
| 204 | + var type: String |
| 205 | + var message: String |
| 206 | + var code: String? |
| 207 | + } |
| 208 | + |
| 209 | + enum Error: Swift.Error, LocalizedError { |
| 210 | + case decodeError(Swift.Error) |
| 211 | + case apiError(APIError) |
| 212 | + case otherError(String) |
| 213 | + |
| 214 | + var errorDescription: String? { |
| 215 | + switch self { |
| 216 | + case let .decodeError(error): |
| 217 | + return error.localizedDescription |
| 218 | + case let .apiError(error): |
| 219 | + return error.message |
| 220 | + case let .otherError(message): |
| 221 | + return message |
| 222 | + } |
| 223 | + } |
| 224 | + } |
| 225 | +} |
| 226 | + |
| 227 | +// MARK: - Helper Methods |
| 228 | + |
| 229 | +extension AnthropicService { |
| 230 | + func validateResponse(_ response: HTTPURLResponse) throws { |
| 231 | + guard (200 ... 299).contains(response.statusCode) else { |
| 232 | + throw Error.otherError("HTTP Error: \(response.statusCode)") |
| 233 | + } |
| 234 | + } |
| 235 | +} |
0 commit comments