|
4 | 4 |
|
5 | 5 | namespace Codewithkyrian\Transformers\Models; |
6 | 6 |
|
7 | | -use Codewithkyrian\Transformers\Exceptions\MissingModelInputException; |
8 | | -use Codewithkyrian\Transformers\Exceptions\ModelExecutionException; |
9 | 7 | use Codewithkyrian\Transformers\Models\Pretrained\PretrainedModel; |
10 | 8 | use Codewithkyrian\Transformers\Tensor\Tensor; |
11 | | -use Codewithkyrian\Transformers\Utils\GenerationConfig; |
| 9 | + |
12 | 10 | use function Codewithkyrian\Transformers\Utils\array_pick; |
13 | 11 | use function Codewithkyrian\Transformers\Utils\array_pop_key; |
14 | 12 |
|
@@ -102,8 +100,8 @@ function decoderPrepareInputsForGeneration(PretrainedModel $model, $inputIds, ar |
102 | 100 | } // Case 3: Past length >= Input IDs |
103 | 101 | else { |
104 | 102 | if ( |
105 | | - isset($model->config->image_token_index) && |
106 | | - in_array($model->config->image_token_index, $inputIds->toArray()) |
| 103 | + isset($model->config['image_token_index']) && |
| 104 | + in_array($model->config['image_token_index'], $inputIds->toArray()) |
107 | 105 | ) { |
108 | 106 | // Support for multiple image tokens |
109 | 107 | $numImageTokens = $model->config['num_image_tokens'] ?? null; |
@@ -181,7 +179,7 @@ protected function seq2seqForward(PretrainedModel $model, array $modelInputs): a |
181 | 179 | return $this->decoderForward($model, $decoderFeeds, true); |
182 | 180 | } |
183 | 181 |
|
184 | | - protected function createPositionIds(array $modelInputs, array $pastKeyValues = null): Tensor |
| 182 | + protected function createPositionIds(array $modelInputs, ?array $pastKeyValues = null): Tensor |
185 | 183 | { |
186 | 184 | $inputIds = $modelInputs['input_ids'] ?? null; |
187 | 185 | $inputsEmbeds = $modelInputs['inputs_embeds'] ?? null; |
@@ -211,7 +209,7 @@ protected function createPositionIds(array $modelInputs, array $pastKeyValues = |
211 | 209 | $positionIds = new Tensor($data, Tensor::int64, $attentionMask->shape()); |
212 | 210 |
|
213 | 211 | if ($pastKeyValues) { |
214 | | - $offset = -(($inputIds ?? $inputsEmbeds)->shape()[1]); |
| 212 | + $offset = - (($inputIds ?? $inputsEmbeds)->shape()[1]); |
215 | 213 | $positionIds = $positionIds->slice(null, [$offset, null]); // position_ids[:, -input_ids.shape[1] :] |
216 | 214 | } |
217 | 215 |
|
|
0 commit comments