Skip to content

Implement AiProviderRegistry with comprehensive test suite #38

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 9 commits into
base: trunk
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
229 changes: 229 additions & 0 deletions src/Providers/AiProviderRegistry.php
Original file line number Diff line number Diff line change
@@ -0,0 +1,229 @@
<?php

declare(strict_types=1);

namespace WordPress\AiClient\Providers;

use InvalidArgumentException;
use WordPress\AiClient\Providers\Contracts\ModelMetadataDirectoryInterface;
use WordPress\AiClient\Providers\Contracts\ProviderAvailabilityInterface;
use WordPress\AiClient\Providers\Contracts\ProviderInterface;
use WordPress\AiClient\Providers\DTO\ProviderMetadata;
use WordPress\AiClient\Providers\DTO\ProviderModelsMetadata;
use WordPress\AiClient\Providers\Models\Contracts\ModelInterface;
use WordPress\AiClient\Providers\Models\DTO\ModelConfig;
use WordPress\AiClient\Providers\Models\DTO\ModelMetadata;
use WordPress\AiClient\Providers\Models\DTO\ModelRequirements;

/**
* Registry for managing AI providers and their models.
*
* This class provides a centralized way to register AI providers, discover
* their capabilities, and find suitable models based on requirements.
*
* @since n.e.x.t
*/
class AiProviderRegistry
{
/**
* @var array<string, string> Mapping of provider IDs to class names.
*/
private array $providerClassNames = [];


/**
* Registers a provider class with the registry.
*
* @since n.e.x.t
*
* @param string $className The fully qualified provider class name.
* @throws InvalidArgumentException If the class doesn't exist or implement required interface.
*/
public function registerProvider(string $className): void
{
if (!class_exists($className)) {
throw new InvalidArgumentException(
sprintf('Provider class does not exist: %s', $className)
);
}

// Validate that class implements ProviderInterface
if (!is_subclass_of($className, ProviderInterface::class)) {
throw new InvalidArgumentException(
sprintf('Provider class must implement %s: %s', ProviderInterface::class, $className)
);
}

// Get provider metadata to extract ID (using static method from interface)
/** @var class-string<ProviderInterface> $className */
$metadata = $className::metadata();

if (!$metadata instanceof ProviderMetadata) {
throw new InvalidArgumentException(
sprintf('Provider must return ProviderMetadata from metadata() method: %s', $className)
);
}

$this->providerClassNames[$metadata->getId()] = $className;
}

/**
* Checks if a provider is registered.
*
* @since n.e.x.t
*
* @param string $idOrClassName The provider ID or class name to check.
* @return bool True if the provider is registered.
*/
public function hasProvider(string $idOrClassName): bool
{
return isset($this->providerClassNames[$idOrClassName]) ||
in_array($idOrClassName, $this->providerClassNames, true);
}

/**
* Gets the class name for a registered provider.
*
* @since n.e.x.t
*
* @param string $id The provider ID.
* @return string The provider class name.
* @throws InvalidArgumentException If the provider is not registered.
*/
public function getProviderClassName(string $id): string
{
if (!isset($this->providerClassNames[$id])) {
throw new InvalidArgumentException(
sprintf('Provider not registered: %s', $id)
);
}

return $this->providerClassNames[$id];
}

/**
* Checks if a provider is properly configured.
*
* @since n.e.x.t
*
* @param string $idOrClassName The provider ID or class name.
* @return bool True if the provider is configured and ready to use.
*/
public function isProviderConfigured(string $idOrClassName): bool
{
try {
$className = $this->resolveProviderClassName($idOrClassName);

// Use static method from ProviderInterface
/** @var class-string<ProviderInterface> $className */
$availability = $className::availability();

return $availability->isConfigured();
} catch (InvalidArgumentException $e) {
return false;
}
}

/**
* Finds models across all providers that support the given requirements.
*
* @since n.e.x.t
*
* @param ModelRequirements $modelRequirements The requirements to match against.
* @return list<ProviderModelsMetadata> List of provider models metadata that match requirements.
*/
public function findModelsMetadataForSupport(ModelRequirements $modelRequirements): array
{
$results = [];

foreach ($this->providerClassNames as $providerId => $className) {
$providerResults = $this->findProviderModelsMetadataForSupport($providerId, $modelRequirements);
if (!empty($providerResults)) {
// Use static method from ProviderInterface
/** @var class-string<ProviderInterface> $className */
$providerMetadata = $className::metadata();

$results[] = new ProviderModelsMetadata(
$providerMetadata,
$providerResults
);
}
}

return $results;
}

/**
* Finds models within a specific provider that support the given requirements.
*
* @since n.e.x.t
*
* @param string $idOrClassName The provider ID or class name.
* @param ModelRequirements $modelRequirements The requirements to match against.
* @return list<ModelMetadata> List of model metadata that match requirements.
*/
public function findProviderModelsMetadataForSupport(
string $idOrClassName,
ModelRequirements $modelRequirements
): array {
$className = $this->resolveProviderClassName($idOrClassName);

// Use static method from ProviderInterface
/** @var class-string<ProviderInterface> $className */
$modelMetadataDirectory = $className::modelMetadataDirectory();

// Filter models that meet requirements
$matchingModels = [];
foreach ($modelMetadataDirectory->listModelMetadata() as $modelMetadata) {
if ($modelMetadata->meetsRequirements($modelRequirements)) {
$matchingModels[] = $modelMetadata;
}
}

return $matchingModels;
}

/**
* Gets a configured model instance from a provider.
*
* @since n.e.x.t
*
* @param string $idOrClassName The provider ID or class name.
* @param string $modelId The model identifier.
* @param ModelConfig|null $modelConfig The model configuration.
* @return ModelInterface The configured model instance.
* @throws InvalidArgumentException If provider or model is not found.
*/
public function getProviderModel(
string $idOrClassName,
string $modelId,
?ModelConfig $modelConfig = null
): ModelInterface {
$className = $this->resolveProviderClassName($idOrClassName);

// Use static method from ProviderInterface
/** @var class-string<ProviderInterface> $className */
return $className::model($modelId, $modelConfig);
}

/**
* Gets the class name for a registered provider (handles both ID and class name input).
*
* @param string $idOrClassName The provider ID or class name.
* @return string The provider class name.
* @throws InvalidArgumentException If provider is not registered.
*/
private function resolveProviderClassName(string $idOrClassName): string
{
// Handle both ID and class name
$className = $this->providerClassNames[$idOrClassName] ?? $idOrClassName;

if (!$this->hasProvider($idOrClassName)) {
throw new InvalidArgumentException(
sprintf('Provider not registered: %s', $idOrClassName)
);
}

return $className;
}
}
23 changes: 11 additions & 12 deletions src/Providers/Models/Contracts/ModelInterface.php
Original file line number Diff line number Diff line change
Expand Up @@ -10,38 +10,37 @@
/**
* Interface for AI models.
*
* Models represent specific AI models from providers and define
* their capabilities, configuration, and execution methods.
* All models must implement this interface to provide
* metadata access and configuration capabilities.
*
* @since n.e.x.t
*/
interface ModelInterface
{
/**
* Gets model metadata.
* Gets the model's metadata.
*
* @since n.e.x.t
*
* @return ModelMetadata Model metadata.
* @return ModelMetadata The model metadata.
*/
public function metadata(): ModelMetadata;
public function getMetadata(): ModelMetadata;

/**
* Sets model configuration.
* Gets the current model configuration.
*
* @since n.e.x.t
*
* @param ModelConfig $config Model configuration.
* @return void
* @return ModelConfig The model configuration.
*/
public function setConfig(ModelConfig $config): void;
public function getConfig(): ModelConfig;

/**
* Gets model configuration.
* Sets the model configuration.
*
* @since n.e.x.t
*
* @return ModelConfig Current model configuration.
* @param ModelConfig $config The model configuration.
*/
public function getConfig(): ModelConfig;
public function setConfig(ModelConfig $config): void;
}
71 changes: 35 additions & 36 deletions src/Providers/Models/DTO/ModelMetadata.php
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,41 @@ public function getSupportedOptions(): array
return $this->supportedOptions;
}

/**
* Checks whether this model meets the specified requirements.
*
* @since n.e.x.t
*
* @param ModelRequirements $requirements The requirements to check against.
* @return bool True if the model meets all requirements, false otherwise.
*/
public function meetsRequirements(ModelRequirements $requirements): bool
{
// Check if all required capabilities are supported using map lookup
foreach ($requirements->getRequiredCapabilities() as $requiredCapability) {
if (!isset($this->capabilitiesMap[$requiredCapability->value])) {
return false;
}
}

// Check if all required options are supported with the specified values
foreach ($requirements->getRequiredOptions() as $requiredOption) {
// Use map lookup instead of linear search
if (!isset($this->optionsMap[$requiredOption->getName()])) {
return false;
}

$supportedOption = $this->optionsMap[$requiredOption->getName()];

// Check if the required value is supported by this option
if (!$supportedOption->isSupportedValue($requiredOption->getValue())) {
return false;
}
}

return true;
}

/**
* {@inheritDoc}
*
Expand Down Expand Up @@ -209,42 +244,6 @@ public function toArray(): array
];
}

/**
* Checks whether this model meets the specified requirements.
*
* @since n.e.x.t
*
* @param ModelRequirements $requirements The requirements to check against.
* @return bool True if the model meets all requirements, false otherwise.
*/
public function meetsRequirements(ModelRequirements $requirements): bool
{
// Check if all required capabilities are supported using map lookup
foreach ($requirements->getRequiredCapabilities() as $requiredCapability) {
if (!isset($this->capabilitiesMap[$requiredCapability->value])) {
return false;
}
}

// Check if all required options are supported with the specified values
foreach ($requirements->getRequiredOptions() as $requiredOption) {
// Use map lookup instead of linear search
if (!isset($this->optionsMap[$requiredOption->getName()])) {
return false;
}

$supportedOption = $this->optionsMap[$requiredOption->getName()];

// Check if the required value is supported by this option
if (!$supportedOption->isSupportedValue($requiredOption->getValue())) {
return false;
}
}

return true;
}


/**
* {@inheritDoc}
*
Expand Down
9 changes: 1 addition & 8 deletions tests/unit/Files/DTO/FileTest.php
Original file line number Diff line number Diff line change
Expand Up @@ -211,10 +211,6 @@ public function testMimeTypeMethods(): void
$this->assertFalse($file->isImage());
$this->assertFalse($file->isAudio());
$this->assertFalse($file->isText());
$this->assertTrue($file->isMimeType('video'));
$this->assertFalse($file->isMimeType('image'));
$this->assertFalse($file->isMimeType('audio'));
$this->assertFalse($file->isMimeType('text'));
}

/**
Expand All @@ -237,10 +233,7 @@ public function testJsonSchema(): void
$this->assertArrayHasKey(File::KEY_FILE_TYPE, $remoteSchema['properties']);
$this->assertArrayHasKey(File::KEY_MIME_TYPE, $remoteSchema['properties']);
$this->assertArrayHasKey(File::KEY_URL, $remoteSchema['properties']);
$this->assertEquals(
[File::KEY_FILE_TYPE, File::KEY_MIME_TYPE, File::KEY_URL],
$remoteSchema['required']
);
$this->assertEquals([File::KEY_FILE_TYPE, File::KEY_MIME_TYPE, File::KEY_URL], $remoteSchema['required']);

// Check inline file schema
$inlineSchema = $schema['oneOf'][1];
Expand Down
Loading