Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
111 changes: 20 additions & 91 deletions src/JsonRpc/MessageFactory.php
Original file line number Diff line number Diff line change
Expand Up @@ -35,11 +35,11 @@
final class MessageFactory
{
/**
* Registry of all known notification classes.
* Registry of all known message classes that have methods.
*
* @var array<int, class-string<Notification>>
* @var array<int, class-string<Request|Notification>>
*/
private const REGISTERED_NOTIFICATIONS = [
private const REGISTERED_MESSAGES = [
Schema\Notification\CancelledNotification::class,
Schema\Notification\InitializedNotification::class,
Schema\Notification\LoggingMessageNotification::class,
Expand All @@ -49,14 +49,7 @@ final class MessageFactory
Schema\Notification\ResourceUpdatedNotification::class,
Schema\Notification\RootsListChangedNotification::class,
Schema\Notification\ToolListChangedNotification::class,
];

/**
* Registry of all known request classes.
*
* @var array<int, class-string<Request>>
*/
private const REGISTERED_REQUESTS = [
Schema\Request\CallToolRequest::class,
Schema\Request\CompletionCompleteRequest::class,
Schema\Request\CreateSamplingMessageRequest::class,
Expand All @@ -75,32 +68,24 @@ final class MessageFactory
];

/**
* @param array<int, class-string<Notification>> $registeredNotifications
* @param array<int, class-string<Request>> $registeredRequests
* @param array<int, class-string<Request|Notification>> $registeredMessages
*/
public function __construct(
private readonly array $registeredNotifications,
private readonly array $registeredRequests,
private readonly array $registeredMessages,
) {
foreach ($this->registeredNotifications as $notification) {
if (!is_subclass_of($notification, Notification::class)) {
throw new InvalidArgumentException(\sprintf('Notification classes must extend %s.', Notification::class));
}
}

foreach ($this->registeredRequests as $request) {
if (!is_subclass_of($request, Request::class)) {
throw new InvalidArgumentException(\sprintf('Request classes must extend %s.', Request::class));
foreach ($this->registeredMessages as $messageClass) {
if (!is_subclass_of($messageClass, Request::class) && !is_subclass_of($messageClass, Notification::class)) {
throw new InvalidArgumentException(\sprintf('Message classes must extend %s or %s.', Request::class, Notification::class));
}
}
}

/**
* Creates a new Factory instance with all the protocol's default notifications and requests.
* Creates a new Factory instance with all the protocol's default messages.
*/
public static function make(): self
{
return new self(self::REGISTERED_NOTIFICATIONS, self::REGISTERED_REQUESTS);
return new self(self::REGISTERED_MESSAGES);
}

/**
Expand Down Expand Up @@ -142,10 +127,6 @@ public function create(string $input): array
*/
private function createMessage(array $data): MessageInterface
{
if (!isset($data['jsonrpc']) || MessageInterface::JSONRPC_VERSION !== $data['jsonrpc']) {
throw new InvalidInputMessageException('Invalid or missing "jsonrpc" version.');
}

try {
if (isset($data['error'])) {
return Error::fromArray($data);
Expand All @@ -159,81 +140,29 @@ private function createMessage(array $data): MessageInterface
throw new InvalidInputMessageException('Invalid JSON-RPC message: missing "method", "result", or "error" field.');
}

return isset($data['id']) ? $this->createRequest($data) : $this->createNotification($data);
$messageClass = $this->findMessageClassByMethod($data['method']);

return $messageClass::fromArray($data);
} catch (InvalidArgumentException $e) {
throw new InvalidInputMessageException($e->getMessage(), 0, $e);
}
}

/**
* Creates a Request object by looking up the appropriate class by method name.
*
* @param array<string, mixed> $data
*
* @throws InvalidInputMessageException
*/
private function createRequest(array $data): Request
{
if (!\is_string($data['method'])) {
throw new InvalidInputMessageException('Request "method" must be a string.');
}

$messageClass = $this->findRequestClassByMethod($data['method']);

return $messageClass::fromArray($data);
}

/**
* Creates a Notification object by looking up the appropriate class by method name.
*
* @param array<string, mixed> $data
*
* @throws InvalidInputMessageException
*/
private function createNotification(array $data): Notification
{
if (!\is_string($data['method'])) {
throw new InvalidInputMessageException('Notification "method" must be a string.');
}

$messageClass = $this->findNotificationClassByMethod($data['method']);

return $messageClass::fromArray($data);
}

/**
* Finds the registered request class for a given method name.
*
* @return class-string<Request>
*
* @throws InvalidInputMessageException
*/
private function findRequestClassByMethod(string $method): string
{
foreach ($this->registeredRequests as $requestClass) {
if ($requestClass::getMethod() === $method) {
return $requestClass;
}
}

throw new InvalidInputMessageException(\sprintf('Unknown request method "%s".', $method));
}

/**
* Finds the registered notification class for a given method name.
* Finds the registered message class for a given method name.
*
* @return class-string<Notification>
* @return class-string<Request|Notification>
*
* @throws InvalidInputMessageException
*/
private function findNotificationClassByMethod(string $method): string
private function findMessageClassByMethod(string $method): string
{
foreach ($this->registeredNotifications as $notificationClass) {
if ($notificationClass::getMethod() === $method) {
return $notificationClass;
foreach ($this->registeredMessages as $messageClass) {
if ($messageClass::getMethod() === $method) {
return $messageClass;
}
}

throw new InvalidInputMessageException(\sprintf('Unknown notification method "%s".', $method));
throw new InvalidInputMessageException(\sprintf('Unknown method "%s".', $method));
}
}
28 changes: 13 additions & 15 deletions tests/Unit/JsonRpc/MessageFactoryTest.php
Original file line number Diff line number Diff line change
Expand Up @@ -27,16 +27,12 @@ final class MessageFactoryTest extends TestCase

protected function setUp(): void
{
$this->factory = new MessageFactory(
[
CancelledNotification::class,
InitializedNotification::class,
],
[
GetPromptRequest::class,
PingRequest::class,
]
);
$this->factory = new MessageFactory([
CancelledNotification::class,
InitializedNotification::class,
GetPromptRequest::class,
PingRequest::class,
]);
}

public function testCreateRequestWithIntegerId(): void
Expand Down Expand Up @@ -250,7 +246,7 @@ public function testUnknownMethod(): void

$this->assertCount(1, $results);
$this->assertInstanceOf(InvalidInputMessageException::class, $results[0]);
$this->assertStringContainsString('Unknown request method', $results[0]->getMessage());
$this->assertStringContainsString('Unknown method', $results[0]->getMessage());
}

public function testUnknownNotificationMethod(): void
Expand All @@ -261,18 +257,20 @@ public function testUnknownNotificationMethod(): void

$this->assertCount(1, $results);
$this->assertInstanceOf(InvalidInputMessageException::class, $results[0]);
$this->assertStringContainsString('Unknown notification method', $results[0]->getMessage());
$this->assertStringContainsString('Unknown method', $results[0]->getMessage());
}

public function testResponseMissingId(): void
public function testNotificationMethodUsedAsRequest(): void
{
$json = '{"jsonrpc": "2.0", "result": {"status": "ok"}}';
// When a notification method is used with an id, it should still create the notification
// The fromArray validation will handle any issues
$json = '{"jsonrpc": "2.0", "method": "notifications/initialized", "id": 1}';

$results = $this->factory->create($json);

$this->assertCount(1, $results);
// The notification class will reject the id in fromArray validation
$this->assertInstanceOf(InvalidInputMessageException::class, $results[0]);
$this->assertStringContainsString('id', $results[0]->getMessage());
}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

extra mile 😎 👍


public function testErrorMissingId(): void
Expand Down
1 change: 1 addition & 0 deletions tests/Unit/Server/ProtocolTest.php
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ final class ProtocolTest extends TestCase
{
private MockObject&SessionFactoryInterface $sessionFactory;
private MockObject&SessionStoreInterface $sessionStore;
/** @var MockObject&TransportInterface<mixed> */
private MockObject&TransportInterface $transport;

protected function setUp(): void
Expand Down