Skip to content

Commit 2cf18cc

Browse files
feat: Add support for PostProcessor Sequence
1 parent 0564dd1 commit 2cf18cc

File tree

4 files changed

+72
-3
lines changed

4 files changed

+72
-3
lines changed

src/PostProcessors/PostProcessor.php

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ public static function fromConfig(?array $config): ?self
2828
'ByteLevel' => new ByteLevelPostProcessor($config),
2929
'TemplateProcessing' => new TemplateProcessing($config),
3030
'RobertaProcessing' => new RobertaProcessing($config),
31+
'Sequence' => new PostProcessorSequence($config),
3132
default => throw new \InvalidArgumentException("Unknown post-processor type {$config['type']}"),
3233
};
3334
}
Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
<?php
2+
3+
declare(strict_types=1);
4+
5+
namespace Codewithkyrian\Transformers\PostProcessors;
6+
7+
/**
8+
* A post-processor that applies multiple post-processors in sequence.
9+
*/
10+
class PostProcessorSequence extends PostProcessor
{
    /**
     * Post-processors to apply, in the order they were configured.
     *
     * @var PostProcessor[]
     */
    protected array $processors;

    /**
     * Creates a new instance of PostProcessorSequence.
     *
     * @param array $config The configuration array.
     *                      - 'processors' (array): The list of post-processor configurations to apply.
     *
     * @throws \InvalidArgumentException If 'processors' is missing or is not an array.
     */
    public function __construct(array $config)
    {
        parent::__construct($config);

        $processorConfigs = $config['processors'] ?? null;

        // Fail with a descriptive message instead of an undefined-key warning
        // followed by a TypeError raised from array_map().
        if (!is_array($processorConfigs)) {
            throw new \InvalidArgumentException("Sequence post-processor requires a 'processors' array.");
        }

        // PostProcessor::fromConfig() is declared to return ?self, so drop null
        // results here rather than failing later with a method call on null.
        $this->processors = array_values(array_filter(
            array_map(
                static fn ($processorConfig) => PostProcessor::fromConfig($processorConfig),
                $processorConfigs
            ),
            static fn (?PostProcessor $processor): bool => $processor !== null
        ));
    }

    /**
     * Post-process the given tokens by running every configured processor in order.
     *
     * @param array $tokens The list of tokens for the first sequence.
     * @param string[]|null $tokenPair The input tokens for the second sequence in a pair.
     * @param bool $addSpecialTokens Whether to add the special tokens associated with the corresponding model.
     *
     * @return PostProcessedOutput The tokens produced by the last processor, together with the
     *                             token type ids of the last non-byte-level processor (null if none ran).
     */
    public function postProcess(array $tokens, ?array $tokenPair = null, bool $addSpecialTokens = true): PostProcessedOutput
    {
        $tokenTypeIds = null;

        foreach ($this->processors as $processor) {
            if ($processor instanceof ByteLevelPostProcessor) {
                // Special case: the byte-level processor is run on each sequence
                // separately, and its token type ids are not carried over.
                $tokens = $processor->postProcess($tokens)->tokens;

                if ($tokenPair !== null) {
                    $tokenPair = $processor->postProcess($tokenPair)->tokens;
                }

                continue;
            }

            // NOTE(review): $tokenPair is still forwarded unchanged to any later
            // processor after this call - confirm that is intended.
            $output = $processor->postProcess($tokens, $tokenPair, $addSpecialTokens);
            $tokens = $output->tokens;
            $tokenTypeIds = $output->tokenTypeIds;
        }

        return new PostProcessedOutput($tokens, $tokenTypeIds);
    }
}
67+
68+
?>

src/Utils/AutoConfig.php

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ class AutoConfig implements ArrayAccess
1616

1717
protected array $architectures = [];
1818

19-
public int $padTokenId;
19+
public int|array $padTokenId;
2020

2121
protected int $vocabSize;
2222

src/Utils/GenerationConfig.php

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -124,8 +124,8 @@ class GenerationConfig implements \ArrayAccess
124124
/** @var bool Whether or not to return a `ModelOutput` instead of a plain tuple. */
125125
public bool $return_dict_in_generate;
126126

127-
/** @var int|null The id of the *padding* token. */
128-
public ?int $pad_token_id;
127+
/** @var int|int[]|null The id of the *padding* token. */
128+
public int|array|null $pad_token_id;
129129

130130
/** @var int|null The id of the *beginning-of-sequence* token. */
131131
public ?int $bos_token_id;

0 commit comments

Comments
 (0)