diff --git a/src/ProviderImplementations/OpenAi/OpenAiModelMetadataDirectory.php b/src/ProviderImplementations/OpenAi/OpenAiModelMetadataDirectory.php new file mode 100644 index 0000000..68390da --- /dev/null +++ b/src/ProviderImplementations/OpenAi/OpenAiModelMetadataDirectory.php @@ -0,0 +1,182 @@ +getData(); + if (!isset($responseData['data']) || !$responseData['data']) { + throw new RuntimeException( + 'Unexpected API response: Missing the data key.' + ); + } + + // Unfortunately, the OpenAI API does not return model capabilities, so we have to hardcode them here. + $gptCapabilities = [ + CapabilityEnum::textGeneration(), + CapabilityEnum::chatHistory(), + ]; + $gptOptions = [ + new SupportedOption(ModelConfig::KEY_SYSTEM_INSTRUCTION), + new SupportedOption(ModelConfig::KEY_CANDIDATE_COUNT), + new SupportedOption(ModelConfig::KEY_MAX_TOKENS), + new SupportedOption(ModelConfig::KEY_TEMPERATURE), + new SupportedOption(ModelConfig::KEY_TOP_P), + new SupportedOption(ModelConfig::KEY_STOP_SEQUENCES), + new SupportedOption(ModelConfig::KEY_PRESENCE_PENALTY), + new SupportedOption(ModelConfig::KEY_FREQUENCY_PENALTY), + new SupportedOption(ModelConfig::KEY_LOGPROBS), + new SupportedOption(ModelConfig::KEY_TOP_LOGPROBS), + new SupportedOption(ModelConfig::KEY_OUTPUT_MIME_TYPE, ['text/plain', 'application/json']), + new SupportedOption(ModelConfig::KEY_OUTPUT_SCHEMA), + // TODO: Where to put this as a constant? + new SupportedOption('functionCalling'), + ]; + $gptMultimodalInputOptions = $gptOptions + [ + new SupportedOption( + // TODO: Where to put this as a constant? + 'inputModalities', + [ + [ModalityEnum::text()], + [ModalityEnum::text(), ModalityEnum::image()], + [ModalityEnum::text(), ModalityEnum::image(), ModalityEnum::audio()], + ] + ), + ]; + $gptMultimodalSpeechOutputOptions = $gptMultimodalInputOptions + [ + new SupportedOption( + ModelConfig::KEY_OUTPUT_MODALITIES, + [ + [ModalityEnum::text()], + [ModalityEnum::text(), ModalityEnum::audio()], + ] + ), + ]; + $imageCapabilities = [ + CapabilityEnum::imageGeneration(), + ]; + $dalleImageOptions = [ + new SupportedOption(ModelConfig::KEY_CANDIDATE_COUNT), + new SupportedOption(ModelConfig::KEY_OUTPUT_MIME_TYPE, ['image/png']), + // TODO: Where to put this as a constant? + new SupportedOption('outputFileType', [FileTypeEnum::inline(), FileTypeEnum::remote()]), + // TODO: Where to put this as a constant? + new SupportedOption('imageOrientation', ['square', 'landscape', 'portrait']), + // TODO: Where to put this as a constant? + new SupportedOption('imageAspectRatio', ['1:1', '7:4', '4:7']), + ]; + $gptImageOptions = [ + new SupportedOption(ModelConfig::KEY_CANDIDATE_COUNT), + new SupportedOption(ModelConfig::KEY_OUTPUT_MIME_TYPE, ['image/png', 'image/jpeg', 'image/webp']), + // TODO: Where to put this as a constant? + new SupportedOption('outputFileType', [FileTypeEnum::inline()]), + // TODO: Where to put this as a constant? + new SupportedOption('imageOrientation', ['square', 'landscape', 'portrait']), + // TODO: Where to put this as a constant? + new SupportedOption('imageAspectRatio', ['1:1', '3:2', '2:3']), + ]; + $ttsCapabilities = [ + CapabilityEnum::textToSpeechConversion(), + ]; + $ttsOptions = [ + new SupportedOption(ModelConfig::KEY_OUTPUT_MIME_TYPE, ['audio/mpeg', 'audio/ogg', 'audio/wav']), + // TODO: Where to put this as a constant? + new SupportedOption('voice'), + ]; + + return array_values( + array_map( + static function (array $modelData) use ( + $gptCapabilities, + $gptOptions, + $gptMultimodalInputOptions, + $gptMultimodalSpeechOutputOptions, + $imageCapabilities, + $dalleImageOptions, + $gptImageOptions, + $ttsCapabilities, + $ttsOptions, + ): ModelMetadata { + $modelId = $modelData['id']; + if ( + str_starts_with($modelId, 'dall-e-') || + str_starts_with($modelId, 'gpt-image-') + ) { + $modelCaps = $imageCapabilities; + if (str_starts_with($modelId, 'gpt-image-')) { + $modelOptions = $gptImageOptions; + } else { + $modelOptions = $dalleImageOptions; + } + } elseif ( + str_starts_with($modelId, 'tts-') || + str_contains($modelId, '-tts') + ) { + $modelCaps = $ttsCapabilities; + $modelOptions = $ttsOptions; + } elseif ( + (str_starts_with($modelId, 'gpt-') || str_starts_with($modelId, 'o1-')) + && !str_contains($modelId, '-instruct') + && !str_contains($modelId, '-realtime') + ) { + if (str_starts_with($modelId, 'gpt-4o')) { + $modelCaps = $gptCapabilities; + $modelOptions = $gptMultimodalInputOptions; + // New multimodal output model for audio generation. + if (str_contains($modelId, '-audio')) { + $modelOptions = $gptMultimodalSpeechOutputOptions; + } + } elseif (!str_contains($modelId, '-audio')) { + $modelCaps = $gptCapabilities; + $modelOptions = $gptOptions; + } else { + $modelCaps = []; + $modelOptions = []; + } + } else { + $modelCaps = []; + $modelOptions = []; + } + + return new ModelMetadata( + $modelId, + $modelId, // The OpenAI API does not return a display name. + $modelCaps, + $modelOptions + ); + }, + (array) $responseData['data'] + ) + ); + } +} diff --git a/src/ProviderImplementations/OpenAi/OpenAiProvider.php b/src/ProviderImplementations/OpenAi/OpenAiProvider.php new file mode 100644 index 0000000..cba2c73 --- /dev/null +++ b/src/ProviderImplementations/OpenAi/OpenAiProvider.php @@ -0,0 +1,83 @@ +getSupportedCapabilities(); + foreach ($capabilities as $capability) { + if ($capability->isTextGeneration()) { + return new OpenAiTextGenerationModel($modelMetadata, $providerMetadata); + } + if ($capability->isImageGeneration()) { + // TODO: Implement OpenAiImageGenerationModel. + return new OpenAiImageGenerationModel($modelMetadata, $providerMetadata); + } + if ($capability->isTextToSpeechConversion()) { + // TODO: Implement OpenAiTextToSpeechConversionModel. + return new OpenAiTextToSpeechConversionModel($modelMetadata, $providerMetadata); + } + } + + throw new RuntimeException( + 'Unsupported model capabilities: ' . implode(', ', $capabilities) + ); + } + + /** + * @inheritDoc + */ + protected static function createProviderMetadata(): ProviderMetadata + { + return new ProviderMetadata( + 'openai', + 'OpenAI', + ProviderTypeEnum::cloud() + ); + } + + /** + * @inheritDoc + */ + protected static function createProviderAvailability(): ProviderAvailabilityInterface + { + // Check valid API access by attempting to list models. + return new ListModelsApiBasedProviderAvailability( + static::modelMetadataDirectory() + ); + } + + /** + * @inheritDoc + */ + protected static function createModelMetadataDirectory(): ModelMetadataDirectoryInterface + { + return new OpenAiModelMetadataDirectory(); + } +} diff --git a/src/ProviderImplementations/OpenAi/OpenAiTextGenerationModel.php b/src/ProviderImplementations/OpenAi/OpenAiTextGenerationModel.php new file mode 100644 index 0000000..27d3db4 --- /dev/null +++ b/src/ProviderImplementations/OpenAi/OpenAiTextGenerationModel.php @@ -0,0 +1,25 @@ + Map of model ID to model metadata, effectively for caching. + */ + private ?array $modelMetadataMap = null; + + /** + * @inheritdoc + */ + final public function listModelMetadata(): array + { + $modelsMetadata = $this->getModelMetadataMap(); + return array_values($modelsMetadata); + } + + /** + * @inheritdoc + */ + final public function hasModelMetadata(string $modelId): bool + { + try { + $this->getModelMetadata(); + } catch (InvalidArgumentException $e) { + return false; + } + return true; + } + + /** + * @inheritdoc + */ + final public function getModelMetadata(string $modelId): ModelMetadata + { + $modelsMetadata = $this->getModelMetadataMap(); + if (!isset($modelsMetadata[$modelId])) { + throw new InvalidArgumentException( + sprintf('No model with ID %s was found in the provider', $modelId) + ); + } + return $modelsMetadata[$modelId]; + } + + /** + * Returns the map of model ID to model metadata for all models from the provider. + * + * @since n.e.x.t + * + * @return array Map of model ID to model metadata. + */ + private function getModelMetadataMap(): array + { + if ($this->modelMetadataMap === null) { + $this->modelMetadataMap = $this->sendListModelsRequest(); + } + return $this->modelMetadataMap; + } + + /** + * Sends the API request to list models from the provider and returns the map of model ID to model metadata. + * + * @since n.e.x.t + * + * @return array Map of model ID to model metadata. + */ + abstract protected function sendListModelsRequest(): array; +} diff --git a/src/Providers/AbstractOpenAiCompatibleModelMetadataDirectory.php b/src/Providers/AbstractOpenAiCompatibleModelMetadataDirectory.php new file mode 100644 index 0000000..3417d6c --- /dev/null +++ b/src/Providers/AbstractOpenAiCompatibleModelMetadataDirectory.php @@ -0,0 +1,59 @@ +getHttpTransporter(); + + // Something like this. + $request = $this->createRequest('models'); + $response = $httpTransporter->sendRequest($request); + + $modelsMetadataList = $this->parseResponseToModelMetadataList($response); + + // Parse list to map. + return array_reduce( + $modelsMetadataList, + static function (array $carry, ModelMetadata $metadata) { + $carry[$metadata->getId()] = $metadata; + return $carry; + }, + [] + ); + } + + /** + * Creates a request object for the provider's API. + * + * @since n.e.x.t + * + * @param string $path The API endpoint path, relative to the base URI. + * @return RequestInterface The request object. + */ + abstract protected function createRequest(string $path): RequestInterface; + + /** + * Parses the response from the API endpoint to list models into a list of model metadata objects. + * + * @since n.e.x.t + * + * @param ResponseInterface $response The response from the API endpoint to list models. + * @return list List of model metadata objects. + */ + abstract protected function parseResponseToModelMetadataList(ResponseInterface $response): array; +} diff --git a/src/Providers/AbstractProvider.php b/src/Providers/AbstractProvider.php new file mode 100644 index 0000000..5d8e2ef --- /dev/null +++ b/src/Providers/AbstractProvider.php @@ -0,0 +1,128 @@ + Cache for provider metadata per class. + */ + private static array $metadataCache = []; + + /** + * @var array Cache for provider availability per class. + */ + private static array $availabilityCache = []; + + /** + * @var array Cache for model metadata directory per class. + */ + private static array $modelMetadataDirectoryCache = []; + + /** + * @inheritdoc + */ + final public static function metadata(): ProviderMetadata + { + $className = static::class; + if (!isset(self::$metadataCache[$className])) { + self::$metadataCache[$className] = static::createProviderMetadata(); + } + return self::$metadataCache[$className]; + } + + /** + * @inheritdoc + */ + final public static function model(string $modelId, ?ModelConfig $modelConfig = null): ModelInterface + { + $providerMetadata = static::metadata(); + $modelMetadata = static::modelMetadataDirectory()->getModelMetadata($modelId); + + $model = static::createModel($modelMetadata, $providerMetadata); + if ($modelConfig) { + $model->setConfig($modelConfig); + } + return $model; + } + + /** + * @inheritdoc + */ + final public static function availability(): ProviderAvailabilityInterface + { + $className = static::class; + if (!isset(self::$availabilityCache[$className])) { + self::$availabilityCache[$className] = static::createProviderAvailability(); + } + return self::$availabilityCache[$className]; + } + + /** + * @inheritdoc + */ + final public static function modelMetadataDirectory(): ModelMetadataDirectoryInterface + { + $className = static::class; + if (!isset(self::$modelMetadataDirectoryCache[$className])) { + self::$modelMetadataDirectoryCache[$className] = static::createModelMetadataDirectory(); + } + return self::$modelMetadataDirectoryCache[$className]; + } + + /** + * Creates a model instance based on the given model metadata and provider metadata. + * + * @since n.e.x.t + * + * @param ModelMetadata $modelMetadata The model metadata. + * @param ProviderMetadata $providerMetadata The provider metadata. + * @return ModelInterface The new model instance. + */ + abstract protected static function createModel( + ModelMetadata $modelMetadata, + ProviderMetadata $providerMetadata + ): ModelInterface; + + /** + * Creates the provider metadata instance. + * + * @since n.e.x.t + * + * @return ProviderMetadata The provider metadata. + */ + abstract protected static function createProviderMetadata(): ProviderMetadata; + + /** + * Creates the provider availability instance. + * + * @since n.e.x.t + * + * @return ProviderAvailabilityInterface The provider availability. + */ + abstract protected static function createProviderAvailability(): ProviderAvailabilityInterface; + + /** + * Creates the model metadata directory instance. + * + * @since n.e.x.t + * + * @return ModelMetadataDirectoryInterface The model metadata directory. + */ + abstract protected static function createModelMetadataDirectory(): ModelMetadataDirectoryInterface; +} diff --git a/src/Providers/GenerateTextApiBasedProviderAvailability.php b/src/Providers/GenerateTextApiBasedProviderAvailability.php new file mode 100644 index 0000000..5c2fe8f --- /dev/null +++ b/src/Providers/GenerateTextApiBasedProviderAvailability.php @@ -0,0 +1,70 @@ +model = $model; + } + + /** + * @inheritdoc + */ + public function isConfigured(): bool + { + // Set config to use as few resources as possible for the test. + $modelConfig = ModelConfig::fromArray([ + ModelConfig::KEY_MAX_TOKENS => 1, + ]); + $this->model->setConfig($modelConfig); + + try { + // Attempt to generate text to check if the provider is available. + $this->model->generateTextResult([ + new Message( + MessageRoleEnum::user(), + [new MessagePart('a')] + ), + ]); + return true; + } catch (Exception $e) { + // If an exception occurs, the provider is not available. + return false; + } + } +} diff --git a/src/Providers/ListModelsApiBasedProviderAvailability.php b/src/Providers/ListModelsApiBasedProviderAvailability.php new file mode 100644 index 0000000..491fa64 --- /dev/null +++ b/src/Providers/ListModelsApiBasedProviderAvailability.php @@ -0,0 +1,50 @@ +modelMetadataDirectory = $modelMetadataDirectory; + } + + /** + * @inheritdoc + */ + public function isConfigured(): bool + { + try { + // Attempt to list models to check if the provider is available. + $this->modelMetadataDirectory->listModelMetadata(); + return true; + } catch (Exception $e) { + // If an exception occurs, the provider is not available. + return false; + } + } +} diff --git a/src/Providers/Models/AbstractApiBasedModel.php b/src/Providers/Models/AbstractApiBasedModel.php new file mode 100644 index 0000000..f553019 --- /dev/null +++ b/src/Providers/Models/AbstractApiBasedModel.php @@ -0,0 +1,90 @@ +metadata = $metadata; + $this->providerMetadata = $providerMetadata; + $this->config = ModelConfig::fromArray([]); + } + + /** + * @inheritdoc + */ + final public function metadata(): ModelMetadata + { + return $this->metadata; + } + + /** + * @inheritdoc + */ + final public function providerMetadata(): ProviderMetadata + { + return $this->providerMetadata; + } + + /** + * @inheritdoc + */ + final public function setConfig(ModelConfig $config): void + { + $this->config = $config; + } + + /** + * @inheritdoc + */ + final public function getConfig(): ModelConfig + { + return $this->config; + } +} diff --git a/src/Providers/Models/AbstractOpenAiCompatibleTextGenerationModel.php b/src/Providers/Models/AbstractOpenAiCompatibleTextGenerationModel.php new file mode 100644 index 0000000..738b690 --- /dev/null +++ b/src/Providers/Models/AbstractOpenAiCompatibleTextGenerationModel.php @@ -0,0 +1,473 @@ +getHttpTransporter(); + + $params = $this->prepareGenerateTextParams($prompt); + + // Something like this. + $request = $this->createRequest('chat/completions', $params); + $response = $httpTransporter->sendRequest($request); + + return $this->parseResponseToGenerativeAiResult($response); + } + + /** + * @inheritDoc + */ + final public function streamGenerateTextResult(array $prompt): Generator + { + $params = $this->prepareGenerateTextParams($prompt); + + // TODO: Implement streaming support. + throw new RuntimeException( + 'Streaming is not yet implemented.' + ); + } + + /** + * Prepares the given prompt and the model configuration into parameters for the API request. + * + * @since n.e.x.t + * + * @param list $prompt The prompt to generate text for. Either a single message or a list of messages + * from a chat. + * @return array The parameters for the API request. + */ + protected function prepareGenerateTextParams(array $prompt): array + { + $config = $this->getConfig(); + + $systemInstruction = $config->getSystemInstruction(); + if ($systemInstruction) { + $prompt = $this->mergeSystemInstruction($prompt, $systemInstruction); + } + + $params = [ + 'model' => $this->metadata()->getId(), + 'messages' => $this->prepareMessagesParam($prompt), + ]; + + $outputModalities = $config->getOutputModalities(); + if (is_array($outputModalities)) { + $this->validateOutputModalities($outputModalities); + if (count($outputModalities) > 1) { + $params['modalities'] = $this->prepareOutputModalitiesParam($outputModalities); + } + } + + // TODO: Prepare other parameters based on config. + + return $params; + } + + /** + * Merges the system instruction into the prompt, ensuring that it is the first message. + * + * @since n.e.x.t + * + * @param list $prompt The prompt to merge the system instruction into. + * @param string $systemInstruction The system instruction to merge. + * @return list The updated prompt with the system instruction as the first message. + * @throws InvalidArgumentException If the first message in the prompt is already a system message. + */ + protected function mergeSystemInstruction(array $prompt, string $systemInstruction): array + { + // If the first message is a system message, throw an exception due to a conflict. + if (isset($prompt[0]) && $prompt[0]->getRole() === MessageRoleEnum::system()) { + throw new InvalidArgumentException( + 'The first message in the prompt cannot be a system message when using a system instruction.' + ); + } + + $systemMessage = new SystemMessage([ + new MessagePart($systemInstruction), + ]); + array_unshift($prompt, $systemMessage); + return $prompt; + } + + /** + * Prepares the messages parameter for the API request. + * + * @since n.e.x.t + * + * @param list $messages The messages to prepare. + * @return list> The prepared messages parameter. + */ + protected function prepareMessagesParam(array $messages): array + { + return array_map( + function (Message $message): array { + // Special case: Function response. + $messageParts = $message->getParts(); + if (count($messageParts) === 1 && $messageParts[0]->getType()->isFunctionResponse()) { + $functionResponse = $messageParts[0]->getFunctionResponse(); + if (!$functionResponse) { + // This should be impossible due to class internals, but still needs to be checked. + throw new RuntimeException( + 'The function response typed message part must contain a function response.' + ); + } + return [ + 'role' => 'tool', + 'content' => json_encode($functionResponse->getResponse()), + 'tool_call_id' => $functionResponse->getId(), + ]; + } + return [ + 'role' => $this->getMessageRoleString($message->getRole()), + 'content' => array_filter(array_map( + function (MessagePart $part): ?array { + return $this->getMessagePartContentData($part); + }, + $messageParts + )), + 'tool_calls' => array_filter(array_map( + function (MessagePart $part): ?array { + return $this->getMessagePartToolCallData($part); + }, + $messageParts + )), + ]; + }, + $messages + ); + } + + /** + * Returns the OpenAI API specific role string for the given message role. + * + * @since n.e.x.t + * + * @param MessageRoleEnum $role The message role. + * @return string The role for the API request. + */ + protected function getMessageRoleString(MessageRoleEnum $role): string + { + if ($role === MessageRoleEnum::model()) { + return 'assistant'; + } + if ($role === MessageRoleEnum::system()) { + return 'system'; + } + return 'user'; + } + + /** + * Returns the OpenAI API specific content data for a message part. + * + * @since n.e.x.t + * + * @param MessagePart $part The message part to get the data for. + * @return ?array The data for the message content part, or null if not applicable. + * @throws InvalidArgumentException If the message part type or data is unsupported. + */ + protected function getMessagePartContentData(MessagePart $part): ?array + { + $type = $part->getType(); + if ($type->isText()) { + return [ + 'type' => 'text', + 'text' => $part->getText(), + ]; + } + if ($type->isFile()) { + $file = $part->getFile(); + if (!$file) { + // This should be impossible due to class internals, but still needs to be checked. + throw new RuntimeException( + 'The file typed message part must contain a file.' + ); + } + if ($file->getFileType()->isRemote()) { + if ($file->isImage()) { + return [ + 'type' => 'image_url', + 'image_url' => [ + 'url' => $file->getUrl(), + ], + ]; + } + throw new InvalidArgumentException( + sprintf( + 'Unsupported MIME type "%s" for remote file message part.', + $file->getMimeType() + ) + ); + } + // Else, it is an inline file. + if ($file->isImage()) { + return [ + 'type' => 'image_url', + 'image_url' => [ + 'url' => $file->getBase64Data(), + ], + ]; + } + if ($file->isAudio()) { + return [ + 'type' => 'input_audio', + 'input_audio' => [ + 'data' => $file->getBase64Data(), + 'format' => '', // TODO: Add method to transform MIME type into file extension. + ], + ]; + } + throw new InvalidArgumentException( + sprintf( + 'Unsupported MIME type "%s" for inline file message part.', + $file->getMimeType() + ) + ); + } + if ($type->isFunctionCall()) { + // Skip, as this is separately included. See `getMessagePartToolCallData()`. + return null; + } + if ($type->isFunctionResponse()) { + // Special case: Function response. + throw new InvalidArgumentException( + 'The API only allows a single function response, as the only content of the message.' + ); + } + throw new InvalidArgumentException( + sprintf( + 'Unsupported message part type "%s".', + $type + ) + ); + } + + /** + * Returns the OpenAI API specific tool calls data for a message part. + * + * @since n.e.x.t + * + * @param MessagePart $part The message part to get the data for. + * @return ?array The data for the message tool call part, or null if not applicable. + * @throws InvalidArgumentException If the message part type or data is unsupported. + */ + protected function getMessagePartToolCallData(MessagePart $part): ?array + { + $type = $part->getType(); + if ($type->isFunctionCall()) { + $functionCall = $part->getFunctionCall(); + if (!$functionCall) { + // This should be impossible due to class internals, but still needs to be checked. + throw new RuntimeException( + 'The function call typed message part must contain a function call.' + ); + } + return [ + 'type' => 'function', + 'id' => $functionCall->getId(), + 'function' => [ + 'name' => $functionCall->getName(), + 'arguments' => json_encode($functionCall->getArgs()), + ], + ]; + } + // All other types are handled in `getMessagePartContentData()`. + return null; + } + + /** + * Validates that the given output modalities to ensure that at least one output modality is text. + * + * @since n.e.x.t + * + * @param array $outputModalities The output modalities to validate. + * @throws InvalidArgumentException If no text output modality is present. + */ + protected function validateOutputModalities(array $outputModalities): void + { + // If no output modalities are set, it's fine, as we can assume text. + if (count($outputModalities) === 0) { + return; + } + + foreach ($outputModalities as $modality) { + if ($modality->isText()) { + return; + } + } + + throw new InvalidArgumentException( + 'A text output modality must be present when generating text.' + ); + } + + /** + * Prepares the output modalities parameter for the API request. + * + * @since n.e.x.t + * + * @param array $modalities The modalities to prepare. + * @return list The prepared modalities parameter. + */ + protected function prepareOutputModalitiesParam(array $modalities): array + { + $prepared = []; + foreach ($modalities as $modality) { + if ($modality->isText()) { + $prepared[] = 'text'; + } elseif ($modality->isImage()) { + $prepared[] = 'image'; + } elseif ($modality->isAudio()) { + $prepared[] = 'audio'; + } else { + throw new InvalidArgumentException( + sprintf( + 'Unsupported output modality "%s".', + $modality + ) + ); + } + } + return $prepared; + } + + /** + * Creates a request object for the provider's API. + * + * @since n.e.x.t + * + * @param string $path The API endpoint path, relative to the base URI. + * @param array $params The parameters for the API request. + * @return RequestInterface The request object. + */ + abstract protected function createRequest(string $path, array $params): RequestInterface; + + /** + * Parses the response from the API endpoint to a generative AI result. + * + * @since n.e.x.t + * + * @param ResponseInterface $response The response from the API endpoint. + * @return GenerativeAiResult The parsed generative AI result. + */ + protected function parseResponseToGenerativeAiResult(ResponseInterface $response): GenerativeAiResult + { + $responseData = $response->getData(); + if (!isset($responseData['choices']) || !$responseData['choices']) { + throw new RuntimeException( + 'Unexpected API response: Missing the choices key.' + ); + } + if (!is_array($responseData['choices'])) { + throw new RuntimeException( + 'Unexpected API response: The choices key must contain an array.' + ); + } + + $candidates = []; + foreach ($responseData['choices'] as $choice) { + if (!is_array($choice)) { + throw new RuntimeException( + 'Unexpected API response: Each element in the choices key must be an associative array.' + ); + } + $candidates[] = $this->parseResponseChoiceToCandidate($choice); + } + + $id = $responseData['id'] ?? ''; + $tokenUsage = new TokenUsage( + $responseData['usage']['prompt_tokens'] ?? 0, + $responseData['usage']['completion_tokens'] ?? 0, + $responseData['usage']['total_tokens'] ?? 0 + ); + + // Use any other data from the response as provider metadata. + $providerMetadata = $responseData; + unset($providerMetadata['id'], $providerMetadata['choices'], $providerMetadata['usage']); + + return new GenerativeAiResult( + $id, + $candidates, + $tokenUsage, + $providerMetadata + ); + } + + /** + * Parses a single choice from the API response into a Candidate object. + * + * @since n.e.x.t + * + * @param array $choice The choice data from the API response. + * @return Candidate The parsed candidate. + * @throws RuntimeException If the choice data is invalid. + */ + protected function parseResponseChoiceToCandidate(array $choice): Candidate + { + if (!isset($choice['message']) || !is_array($choice['message'])) { + throw new RuntimeException( + 'Unexpected API response: Each choice must contain a message key with an associative array.' + ); + } + + // TODO: Correctly implement this, as this is not correct - 'message' isn't just a string. + $message = new Message($choice['message']); + + if (!isset($choice['finish_reason']) || !is_string($choice['finish_reason'])) { + throw new RuntimeException( + 'Unexpected API response: Each choice must contain a finish_reason key with a string value.' + ); + } + switch ($choice['finish_reason']) { + case 'stop': + $finishReason = FinishReasonEnum::stop(); + break; + case 'length': + $finishReason = FinishReasonEnum::length(); + break; + case 'content_filter': + $finishReason = FinishReasonEnum::contentFilter(); + break; + case 'tool_calls': + $finishReason = FinishReasonEnum::toolCalls(); + break; + default: + throw new RuntimeException( + sprintf( + 'Unexpected API response: Invalid finish reason "%s".', + $choice['finish_reason'] + ) + ); + } + + return new Candidate($message, $finishReason); + } +} diff --git a/src/Providers/Models/Contracts/ModelInterface.php b/src/Providers/Models/Contracts/ModelInterface.php index e0448e0..1c6adbc 100644 --- a/src/Providers/Models/Contracts/ModelInterface.php +++ b/src/Providers/Models/Contracts/ModelInterface.php @@ -4,6 +4,7 @@ namespace WordPress\AiClient\Providers\Models\Contracts; +use WordPress\AiClient\Providers\DTO\ProviderMetadata; use WordPress\AiClient\Providers\Models\DTO\ModelConfig; use WordPress\AiClient\Providers\Models\DTO\ModelMetadata; @@ -26,6 +27,15 @@ interface ModelInterface */ public function metadata(): ModelMetadata; + /** + * Returns the metadata for the model's provider. + * + * @since n.e.x.t + * + * @return ProviderMetadata The provider metadata. + */ + public function providerMetadata(): ProviderMetadata; + /** * Sets model configuration. * diff --git a/src/Providers/Models/Traits/WithHttpTransporterTrait.php b/src/Providers/Models/Traits/WithHttpTransporterTrait.php new file mode 100644 index 0000000..251d754 --- /dev/null +++ b/src/Providers/Models/Traits/WithHttpTransporterTrait.php @@ -0,0 +1,33 @@ +httpTransporter = $httpTransporter; + } + + public function getHttpTransporter(): HttpTransporterInterface + { + if ($this->httpTransporter === null) { + throw new RuntimeException( + 'HttpTransporterInterface instance not set. Make sure you use the AiClient class for all requests.' + ); + } + return $this->httpTransporter; + } +}