Skip to content

Commit 6eecb87

Browse files
committed
Allow message parts to have channel, to separate regular content from thought / reasoning content.
1 parent 3c12da3 commit 6eecb87

File tree

3 files changed

+148
-8
lines changed

3 files changed

+148
-8
lines changed

src/Messages/DTO/MessagePart.php

Lines changed: 54 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
use RuntimeException;
99
use WordPress\AiClient\Common\AbstractDataTransferObject;
1010
use WordPress\AiClient\Files\DTO\File;
11+
use WordPress\AiClient\Messages\Enums\MessagePartChannelEnum;
1112
use WordPress\AiClient\Messages\Enums\MessagePartTypeEnum;
1213
use WordPress\AiClient\Tools\DTO\FunctionCall;
1314
use WordPress\AiClient\Tools\DTO\FunctionResponse;
@@ -25,6 +26,7 @@
2526
* @phpstan-import-type FunctionResponseArrayShape from FunctionResponse
2627
*
2728
* @phpstan-type MessagePartArrayShape array{
29+
* channel: string,
2830
* type: string,
2931
* text?: string,
3032
* file?: FileArrayShape,
@@ -36,11 +38,18 @@
3638
*/
3739
class MessagePart extends AbstractDataTransferObject
3840
{
41+
public const KEY_CHANNEL = 'channel';
3942
public const KEY_TYPE = 'type';
4043
public const KEY_TEXT = 'text';
4144
public const KEY_FILE = 'file';
4245
public const KEY_FUNCTION_CALL = 'functionCall';
4346
public const KEY_FUNCTION_RESPONSE = 'functionResponse';
47+
48+
/**
49+
* @var MessagePartChannelEnum The channel this message part belongs to.
50+
*/
51+
private MessagePartChannelEnum $channel;
52+
4453
/**
4554
* @var MessagePartTypeEnum The type of this message part.
4655
*/
@@ -72,10 +81,13 @@ class MessagePart extends AbstractDataTransferObject
7281
* @since n.e.x.t
7382
*
7483
* @param mixed $content The content of this message part.
84+
* @param MessagePartChannelEnum|null $channel The channel this part belongs to. Defaults to CONTENT.
7585
* @throws InvalidArgumentException If an unsupported content type is provided.
7686
*/
77-
public function __construct($content)
87+
public function __construct($content, ?MessagePartChannelEnum $channel = null)
7888
{
89+
$this->channel = $channel ?? MessagePartChannelEnum::content();
90+
7991
if (is_string($content)) {
8092
$this->type = MessagePartTypeEnum::text();
8193
$this->text = $content;
@@ -100,6 +112,18 @@ public function __construct($content)
100112
}
101113
}
102114

115+
/**
116+
* Gets the channel this message part belongs to.
117+
*
118+
* @since n.e.x.t
119+
*
120+
* @return MessagePartChannelEnum The channel.
121+
*/
122+
public function getChannel(): MessagePartChannelEnum
123+
{
124+
return $this->channel;
125+
}
126+
103127
/**
104128
* Gets the type of this message part.
105129
*
@@ -167,11 +191,18 @@ public function getFunctionResponse(): ?FunctionResponse
167191
*/
168192
public static function getJsonSchema(): array
169193
{
194+
$channelSchema = [
195+
'type' => 'string',
196+
'enum' => MessagePartChannelEnum::getValues(),
197+
'description' => 'The channel this message part belongs to.',
198+
];
199+
170200
return [
171201
'oneOf' => [
172202
[
173203
'type' => 'object',
174204
'properties' => [
205+
self::KEY_CHANNEL => $channelSchema,
175206
self::KEY_TYPE => [
176207
'type' => 'string',
177208
'const' => MessagePartTypeEnum::text()->value,
@@ -187,6 +218,7 @@ public static function getJsonSchema(): array
187218
[
188219
'type' => 'object',
189220
'properties' => [
221+
self::KEY_CHANNEL => $channelSchema,
190222
self::KEY_TYPE => [
191223
'type' => 'string',
192224
'const' => MessagePartTypeEnum::file()->value,
@@ -199,6 +231,7 @@ public static function getJsonSchema(): array
199231
[
200232
'type' => 'object',
201233
'properties' => [
234+
self::KEY_CHANNEL => $channelSchema,
202235
self::KEY_TYPE => [
203236
'type' => 'string',
204237
'const' => MessagePartTypeEnum::functionCall()->value,
@@ -211,6 +244,7 @@ public static function getJsonSchema(): array
211244
[
212245
'type' => 'object',
213246
'properties' => [
247+
self::KEY_CHANNEL => $channelSchema,
214248
self::KEY_TYPE => [
215249
'type' => 'string',
216250
'const' => MessagePartTypeEnum::functionResponse()->value,
@@ -233,7 +267,10 @@ public static function getJsonSchema(): array
233267
*/
234268
public function toArray(): array
235269
{
236-
$data = [self::KEY_TYPE => $this->type->value];
270+
$data = [
271+
self::KEY_CHANNEL => $this->channel->value,
272+
self::KEY_TYPE => $this->type->value,
273+
];
237274

238275
if ($this->text !== null) {
239276
$data[self::KEY_TEXT] = $this->text;
@@ -260,15 +297,26 @@ public function toArray(): array
260297
*/
261298
public static function fromArray(array $array): self
262299
{
300+
if (isset($array[self::KEY_CHANNEL])) {
301+
if (!MessagePartChannelEnum::isValidValue($array[self::KEY_CHANNEL])) {
302+
throw new InvalidArgumentException(
303+
sprintf('Invalid channel value: %s', $array[self::KEY_CHANNEL])
304+
);
305+
}
306+
$channel = MessagePartChannelEnum::from($array[self::KEY_CHANNEL]);
307+
} else {
308+
$channel = null;
309+
}
310+
263311
// Check which properties are set to determine how to construct the MessagePart
264312
if (isset($array[self::KEY_TEXT])) {
265-
return new self($array[self::KEY_TEXT]);
313+
return new self($array[self::KEY_TEXT], $channel);
266314
} elseif (isset($array[self::KEY_FILE])) {
267-
return new self(File::fromArray($array[self::KEY_FILE]));
315+
return new self(File::fromArray($array[self::KEY_FILE]), $channel);
268316
} elseif (isset($array[self::KEY_FUNCTION_CALL])) {
269-
return new self(FunctionCall::fromArray($array[self::KEY_FUNCTION_CALL]));
317+
return new self(FunctionCall::fromArray($array[self::KEY_FUNCTION_CALL]), $channel);
270318
} elseif (isset($array[self::KEY_FUNCTION_RESPONSE])) {
271-
return new self(FunctionResponse::fromArray($array[self::KEY_FUNCTION_RESPONSE]));
319+
return new self(FunctionResponse::fromArray($array[self::KEY_FUNCTION_RESPONSE]), $channel);
272320
} else {
273321
throw new InvalidArgumentException(
274322
'MessagePart requires one of: text, file, functionCall, or functionResponse.'
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
<?php
2+
3+
declare(strict_types=1);
4+
5+
namespace WordPress\AiClient\Messages\Enums;
6+
7+
use WordPress\AiClient\Common\AbstractEnum;
8+
9+
/**
10+
* Enum for message part channels.
11+
*
12+
* @since n.e.x.t
13+
*
14+
* @method static self content() Creates an instance for CONTENT channel.
15+
* @method static self thought() Creates an instance for THOUGHT channel.
16+
* @method bool isContent() Checks if the channel is CONTENT.
17+
* @method bool isThought() Checks if the channel is THOUGHT.
18+
*/
19+
class MessagePartChannelEnum extends AbstractEnum
20+
{
21+
/**
22+
* Regular (primary) content.
23+
*/
24+
public const CONTENT = 'content';
25+
26+
/**
27+
* Model thinking or reasoning.
28+
*/
29+
public const THOUGHT = 'thought';
30+
}

tests/unit/Messages/DTO/MessagePartTest.php

Lines changed: 64 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
use WordPress\AiClient\Files\DTO\File;
1111
use WordPress\AiClient\Files\Enums\FileTypeEnum;
1212
use WordPress\AiClient\Messages\DTO\MessagePart;
13+
use WordPress\AiClient\Messages\Enums\MessagePartChannelEnum;
1314
use WordPress\AiClient\Messages\Enums\MessagePartTypeEnum;
1415
use WordPress\AiClient\Tools\DTO\FunctionCall;
1516
use WordPress\AiClient\Tools\DTO\FunctionResponse;
@@ -31,6 +32,7 @@ public function testCreateWithTextContent(): void
3132

3233
$this->assertEquals(MessagePartTypeEnum::text(), $part->getType());
3334
$this->assertEquals($text, $part->getText());
35+
$this->assertEquals(MessagePartChannelEnum::content(), $part->getChannel());
3436
$this->assertNull($part->getFile());
3537
$this->assertNull($part->getFunctionCall());
3638
$this->assertNull($part->getFunctionResponse());
@@ -47,6 +49,7 @@ public function testCreateWithFileContent(): void
4749
$part = new MessagePart($file);
4850

4951
$this->assertEquals(MessagePartTypeEnum::file(), $part->getType());
52+
$this->assertEquals(MessagePartChannelEnum::content(), $part->getChannel());
5053
$this->assertNull($part->getText());
5154
$this->assertSame($file, $part->getFile());
5255
$this->assertNull($part->getFunctionCall());
@@ -64,6 +67,7 @@ public function testCreateWithFunctionCallContent(): void
6467
$part = new MessagePart($functionCall);
6568

6669
$this->assertEquals(MessagePartTypeEnum::functionCall(), $part->getType());
70+
$this->assertEquals(MessagePartChannelEnum::content(), $part->getChannel());
6771
$this->assertNull($part->getText());
6872
$this->assertNull($part->getFile());
6973
$this->assertSame($functionCall, $part->getFunctionCall());
@@ -81,6 +85,7 @@ public function testCreateWithFunctionResponseContent(): void
8185
$part = new MessagePart($functionResponse);
8286

8387
$this->assertEquals(MessagePartTypeEnum::functionResponse(), $part->getType());
88+
$this->assertEquals(MessagePartChannelEnum::content(), $part->getChannel());
8489
$this->assertNull($part->getText());
8590
$this->assertNull($part->getFile());
8691
$this->assertNull($part->getFunctionCall());
@@ -97,6 +102,7 @@ public function testCreateWithEmptyString(): void
97102
$part = new MessagePart('');
98103

99104
$this->assertEquals(MessagePartTypeEnum::text(), $part->getType());
105+
$this->assertEquals(MessagePartChannelEnum::content(), $part->getChannel());
100106
$this->assertEquals('', $part->getText());
101107
}
102108

@@ -260,8 +266,10 @@ public function testToArrayWithText(): void
260266
$json = $part->toArray();
261267

262268
$this->assertIsArray($json);
269+
$this->assertArrayHasKey(MessagePart::KEY_CHANNEL, $json);
263270
$this->assertArrayHasKey(MessagePart::KEY_TYPE, $json);
264271
$this->assertArrayHasKey(MessagePart::KEY_TEXT, $json);
272+
$this->assertEquals(MessagePartChannelEnum::content()->value, $json[MessagePart::KEY_CHANNEL]);
265273
$this->assertEquals(MessagePartTypeEnum::text()->value, $json[MessagePart::KEY_TYPE]);
266274
$this->assertEquals('Hello, world!', $json[MessagePart::KEY_TEXT]);
267275

@@ -283,8 +291,10 @@ public function testToArrayWithFile(): void
283291
$json = $part->toArray();
284292

285293
$this->assertIsArray($json);
294+
$this->assertArrayHasKey(MessagePart::KEY_CHANNEL, $json);
286295
$this->assertArrayHasKey(MessagePart::KEY_TYPE, $json);
287296
$this->assertArrayHasKey(MessagePart::KEY_FILE, $json);
297+
$this->assertEquals(MessagePartChannelEnum::content()->value, $json[MessagePart::KEY_CHANNEL]);
288298
$this->assertEquals(MessagePartTypeEnum::file()->value, $json[MessagePart::KEY_TYPE]);
289299
$this->assertIsArray($json[MessagePart::KEY_FILE]);
290300
}
@@ -297,13 +307,15 @@ public function testToArrayWithFile(): void
297307
public function testFromArrayWithText(): void
298308
{
299309
$json = [
310+
MessagePart::KEY_CHANNEL => MessagePartChannelEnum::thought()->value,
300311
MessagePart::KEY_TYPE => MessagePartTypeEnum::text()->value,
301312
MessagePart::KEY_TEXT => 'Test message'
302313
];
303314

304315
$part = MessagePart::fromArray($json);
305316

306317
$this->assertEquals(MessagePartTypeEnum::text(), $part->getType());
318+
$this->assertEquals(MessagePartChannelEnum::thought(), $part->getChannel());
307319
$this->assertEquals('Test message', $part->getText());
308320
}
309321

@@ -315,6 +327,7 @@ public function testFromArrayWithText(): void
315327
public function testFromArrayWithFile(): void
316328
{
317329
$json = [
330+
MessagePart::KEY_CHANNEL => MessagePartChannelEnum::content()->value,
318331
MessagePart::KEY_TYPE => MessagePartTypeEnum::file()->value,
319332
MessagePart::KEY_FILE => [
320333
File::KEY_FILE_TYPE => FileTypeEnum::remote()->value,
@@ -326,6 +339,7 @@ public function testFromArrayWithFile(): void
326339
$part = MessagePart::fromArray($json);
327340

328341
$this->assertEquals(MessagePartTypeEnum::file(), $part->getType());
342+
$this->assertEquals(MessagePartChannelEnum::content(), $part->getChannel());
329343
$this->assertInstanceOf(File::class, $part->getFile());
330344
$this->assertEquals('https://example.com/image.jpg', $part->getFile()->getUrl());
331345
}
@@ -338,10 +352,11 @@ public function testFromArrayWithFile(): void
338352
public function testArrayRoundTrip(): void
339353
{
340354
// Test with text
341-
$textPart = new MessagePart('Test text');
355+
$textPart = new MessagePart('Test text', MessagePartChannelEnum::thought());
342356
$textJson = $textPart->toArray();
343357
$restoredText = MessagePart::fromArray($textJson);
344358
$this->assertEquals($textPart->getText(), $restoredText->getText());
359+
$this->assertEquals($textPart->getChannel(), $restoredText->getChannel());
345360

346361
// Test with file
347362
$file = new File('https://example.com/doc.pdf', 'application/pdf');
@@ -350,15 +365,17 @@ public function testArrayRoundTrip(): void
350365
$restoredFile = MessagePart::fromArray($fileJson);
351366
$this->assertEquals($file->getUrl(), $restoredFile->getFile()->getUrl());
352367
$this->assertEquals($file->getMimeType(), $restoredFile->getFile()->getMimeType());
368+
$this->assertEquals($filePart->getChannel(), $restoredFile->getChannel());
353369

354370
// Test with function call
355371
$functionCall = new FunctionCall('id_123', 'getData', ['key' => 'value']);
356-
$funcPart = new MessagePart($functionCall);
372+
$funcPart = new MessagePart($functionCall, MessagePartChannelEnum::thought());
357373
$funcJson = $funcPart->toArray();
358374
$restoredFunc = MessagePart::fromArray($funcJson);
359375
$this->assertEquals($functionCall->getId(), $restoredFunc->getFunctionCall()->getId());
360376
$this->assertEquals($functionCall->getName(), $restoredFunc->getFunctionCall()->getName());
361377
$this->assertEquals($functionCall->getArgs(), $restoredFunc->getFunctionCall()->getArgs());
378+
$this->assertEquals($funcPart->getChannel(), $restoredFunc->getChannel());
362379
}
363380

364381
/**
@@ -375,4 +392,49 @@ public function testImplementsWithArrayTransformationInterface(): void
375392
$part
376393
);
377394
}
395+
396+
/**
397+
* Tests creating MessagePart with different channels.
398+
*
399+
* @return void
400+
*/
401+
public function testCreateWithDifferentChannels(): void
402+
{
403+
// Default channel is CONTENT
404+
$part1 = new MessagePart('Some content');
405+
$this->assertEquals(MessagePartChannelEnum::content(), $part1->getChannel());
406+
$this->assertTrue($part1->getChannel()->isContent());
407+
$this->assertFalse($part1->getChannel()->isThought());
408+
409+
// Explicitly set CONTENT channel
410+
$part2 = new MessagePart('Some content', MessagePartChannelEnum::content());
411+
$this->assertEquals(MessagePartChannelEnum::content(), $part2->getChannel());
412+
$this->assertTrue($part2->getChannel()->isContent());
413+
$this->assertFalse($part2->getChannel()->isThought());
414+
415+
// Explicitly set THOUGHT channel
416+
$part3 = new MessagePart('Some thought', MessagePartChannelEnum::thought());
417+
$this->assertEquals(MessagePartChannelEnum::thought(), $part3->getChannel());
418+
$this->assertFalse($part3->getChannel()->isContent());
419+
$this->assertTrue($part3->getChannel()->isThought());
420+
}
421+
422+
/**
423+
* Tests fromArray with an invalid channel value.
424+
*
425+
* @return void
426+
*/
427+
public function testFromArrayWithInvalidChannel(): void
428+
{
429+
$this->expectException(InvalidArgumentException::class);
430+
$this->expectExceptionMessage('Invalid channel value: invalid_channel');
431+
432+
$json = [
433+
MessagePart::KEY_CHANNEL => 'invalid_channel',
434+
MessagePart::KEY_TYPE => MessagePartTypeEnum::text()->value,
435+
MessagePart::KEY_TEXT => 'Test message'
436+
];
437+
438+
MessagePart::fromArray($json);
439+
}
378440
}

0 commit comments

Comments
 (0)