Skip to content

Commit 57d019e

Browse files
refactor: Move AutoConfig and PretrainedConfig to appropriate directories
1 parent 875ba04 commit 57d019e

28 files changed

+99
-190
lines changed
Lines changed: 2 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -2,9 +2,7 @@
22

33
declare(strict_types=1);
44

5-
namespace Codewithkyrian\Transformers\Utils;
6-
7-
use Codewithkyrian\Transformers\Configs\PretrainedConfig;
5+
namespace Codewithkyrian\Transformers\Configs;
86

97
/**
108
* Helper class which is used to instantiate pretrained configs with the `fromPretrained` function.
@@ -18,8 +16,7 @@ public static function fromPretrained(
1816
?string $cacheDir = null,
1917
string $revision = 'main',
2018
?callable $onProgress = null
21-
): PretrainedConfig
22-
{
19+
): PretrainedConfig {
2320
return PretrainedConfig::fromPretrained($modelNameOrPath, $config, $cacheDir, $revision, $onProgress);
2421
}
2522
}
Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -3,7 +3,7 @@
33
declare(strict_types=1);
44

55

6-
namespace Codewithkyrian\Transformers\Utils;
6+
namespace Codewithkyrian\Transformers\Configs;
77

88
/**
99
* Class representing a configuration for a generation task.

src/Generation/LogitsProcessors/WhisperTimeStampLogitsProcessor.php

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -5,8 +5,8 @@
55

66
namespace Codewithkyrian\Transformers\Generation\LogitsProcessors;
77

8+
use Codewithkyrian\Transformers\Configs\GenerationConfig;
89
use Codewithkyrian\Transformers\Tensor\Tensor;
9-
use Codewithkyrian\Transformers\Utils\GenerationConfig;
1010
use function Codewithkyrian\Transformers\Utils\timeUsage;
1111

1212
class WhisperTimeStampLogitsProcessor extends LogitsProcessor
@@ -111,4 +111,4 @@ public function __invoke(array $inputIds, Tensor $logits): Tensor
111111

112112
return $logits;
113113
}
114-
}
114+
}

src/Generation/Samplers/Sampler.php

Lines changed: 2 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -5,17 +5,15 @@
55

66
namespace Codewithkyrian\Transformers\Generation\Samplers;
77

8+
use Codewithkyrian\Transformers\Configs\GenerationConfig;
89
use Codewithkyrian\Transformers\Tensor\Tensor;
9-
use Codewithkyrian\Transformers\Utils\GenerationConfig;
1010

1111
/**
1212
* Sampler is a base class for all sampling methods used for text generation.
1313
*/
1414
abstract class Sampler
1515
{
16-
public function __construct(protected GenerationConfig $generationConfig)
17-
{
18-
}
16+
public function __construct(protected GenerationConfig $generationConfig) {}
1917

2018
/**
2119
* Executes the sampler, using the specified logits.

src/Models/Auto/PretrainedMixin.php

Lines changed: 4 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -5,10 +5,10 @@
55

66
namespace Codewithkyrian\Transformers\Models\Auto;
77

8+
use Codewithkyrian\Transformers\Configs\AutoConfig;
89
use Codewithkyrian\Transformers\Exceptions\UnsupportedModelTypeException;
910
use Codewithkyrian\Transformers\Models\ModelArchitecture;
1011
use Codewithkyrian\Transformers\Models\Pretrained\PretrainedModel;
11-
use Codewithkyrian\Transformers\Utils\AutoConfig;
1212

1313
/**
1414
* Base class of all AutoModels. Contains the `from_pretrained` function
@@ -47,8 +47,7 @@ public static function fromPretrained(
4747
string $revision = 'main',
4848
?string $modelFilename = null,
4949
?callable $onProgress = null
50-
): PretrainedModel
51-
{
50+
): PretrainedModel {
5251
$config = AutoConfig::fromPretrained($modelNameOrPath, $config, $cacheDir, $revision, $onProgress);
5352

5453
foreach (static::MODEL_CLASS_MAPPINGS as $modelClassMapping) {
@@ -72,7 +71,7 @@ public static function fromPretrained(
7271
}
7372

7473
if (static::BASE_IF_FAIL) {
75-
// echo "Unknown model class for model type {$config->modelType}. Using base class PreTrainedModel.";
74+
// echo "Unknown model class for model type {$config->modelType}. Using base class PreTrainedModel.";
7675

7776
return PretrainedModel::fromPretrained(
7877
modelNameOrPath: $modelNameOrPath,
@@ -109,4 +108,4 @@ protected static function getModelArchitecture($modelClass): ModelArchitecture
109108
default => ModelArchitecture::EncoderOnly,
110109
};
111110
}
112-
}
111+
}

src/Models/ModelArchitecture.php

Lines changed: 5 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -4,11 +4,9 @@
44

55
namespace Codewithkyrian\Transformers\Models;
66

7-
use Codewithkyrian\Transformers\Exceptions\MissingModelInputException;
8-
use Codewithkyrian\Transformers\Exceptions\ModelExecutionException;
97
use Codewithkyrian\Transformers\Models\Pretrained\PretrainedModel;
108
use Codewithkyrian\Transformers\Tensor\Tensor;
11-
use Codewithkyrian\Transformers\Utils\GenerationConfig;
9+
1210
use function Codewithkyrian\Transformers\Utils\array_pick;
1311
use function Codewithkyrian\Transformers\Utils\array_pop_key;
1412

@@ -102,8 +100,8 @@ function decoderPrepareInputsForGeneration(PretrainedModel $model, $inputIds, ar
102100
} // Case 3: Past length >= Input IDs
103101
else {
104102
if (
105-
isset($model->config->image_token_index) &&
106-
in_array($model->config->image_token_index, $inputIds->toArray())
103+
isset($model->config['image_token_index']) &&
104+
in_array($model->config['image_token_index'], $inputIds->toArray())
107105
) {
108106
// Support for multiple image tokens
109107
$numImageTokens = $model->config['num_image_tokens'] ?? null;
@@ -181,7 +179,7 @@ protected function seq2seqForward(PretrainedModel $model, array $modelInputs): a
181179
return $this->decoderForward($model, $decoderFeeds, true);
182180
}
183181

184-
protected function createPositionIds(array $modelInputs, array $pastKeyValues = null): Tensor
182+
protected function createPositionIds(array $modelInputs, ?array $pastKeyValues = null): Tensor
185183
{
186184
$inputIds = $modelInputs['input_ids'] ?? null;
187185
$inputsEmbeds = $modelInputs['inputs_embeds'] ?? null;
@@ -211,7 +209,7 @@ protected function createPositionIds(array $modelInputs, array $pastKeyValues =
211209
$positionIds = new Tensor($data, Tensor::int64, $attentionMask->shape());
212210

213211
if ($pastKeyValues) {
214-
$offset = -(($inputIds ?? $inputsEmbeds)->shape()[1]);
212+
$offset = - (($inputIds ?? $inputsEmbeds)->shape()[1]);
215213
$positionIds = $positionIds->slice(null, [$offset, null]); // position_ids[:, -input_ids.shape[1] :]
216214
}
217215

src/Models/Pretrained/BartForConditionalGeneration.php

Lines changed: 3 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -2,13 +2,11 @@
22

33
declare(strict_types=1);
44

5-
65
namespace Codewithkyrian\Transformers\Models\Pretrained;
76

7+
use Codewithkyrian\Transformers\Configs\GenerationConfig;
88
use Codewithkyrian\Transformers\Configs\PretrainedConfig;
99
use Codewithkyrian\Transformers\Models\ModelArchitecture;
10-
use Codewithkyrian\Transformers\Utils\AutoConfig;
11-
use Codewithkyrian\Transformers\Utils\GenerationConfig;
1210
use Codewithkyrian\Transformers\Utils\InferenceSession;
1311

1412
/**
@@ -17,13 +15,12 @@
1715
class BartForConditionalGeneration extends BartPretrainedModel
1816
{
1917
public function __construct(
20-
PretrainedConfig $config,
18+
PretrainedConfig $config,
2119
InferenceSession $session,
2220
public InferenceSession $decoderMergedSession,
2321
public ModelArchitecture $modelArchitecture,
2422
public GenerationConfig $generationConfig
25-
)
26-
{
23+
) {
2724
parent::__construct($config, $session, $modelArchitecture);
2825
}
2926
}

src/Models/Pretrained/CLIPVisionModelWithProjection.php

Lines changed: 4 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -2,11 +2,10 @@
22

33
declare(strict_types=1);
44

5-
65
namespace Codewithkyrian\Transformers\Models\Pretrained;
76

7+
use Codewithkyrian\Transformers\Configs\PretrainedConfig;
88
use Codewithkyrian\Transformers\Models\ModelArchitecture;
9-
use Codewithkyrian\Transformers\Utils\AutoConfig;
109

1110
/**
1211
* CLIP Vision Model with a projection layer on top (a linear layer on top of the pooled output)
@@ -18,17 +17,15 @@ class CLIPVisionModelWithProjection extends CLIPPretrainedModel
1817
public static function fromPretrained(
1918
string $modelNameOrPath,
2019
bool $quantized = true,
21-
AutoConfig|array $config = null,
20+
array|PretrainedConfig|null $config = null,
2221
?string $cacheDir = null,
2322
?string $token = null,
2423
string $revision = 'main',
2524
?string $modelFilename = null,
2625
ModelArchitecture $modelArchitecture = ModelArchitecture::EncoderOnly,
2726
?callable $onProgress = null
28-
): PretrainedModel
29-
{
30-
// Update default model file name if not provided
27+
): PretrainedModel {
3128
$modelFilename ??= 'vision_model';
3229
return parent::fromPretrained($modelNameOrPath, $quantized, $config, $cacheDir, $token, $revision, $modelFilename, $modelArchitecture, $onProgress);
3330
}
34-
}
31+
}

src/Models/Pretrained/GPT2PretrainedModel.php

Lines changed: 1 addition & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -2,15 +2,6 @@
22

33
declare(strict_types=1);
44

5-
65
namespace Codewithkyrian\Transformers\Models\Pretrained;
76

8-
use Codewithkyrian\Transformers\Configs\PretrainedConfig;
9-
use Codewithkyrian\Transformers\Models\ModelArchitecture;
10-
use Codewithkyrian\Transformers\Utils\AutoConfig;
11-
use Codewithkyrian\Transformers\Utils\GenerationConfig;
12-
use Codewithkyrian\Transformers\Utils\InferenceSession;
13-
14-
class GPT2PretrainedModel extends PretrainedModel
15-
{
16-
}
7+
class GPT2PretrainedModel extends PretrainedModel {}

src/Models/Pretrained/GPTBigCodePretrainedModel.php

Lines changed: 1 addition & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -2,15 +2,6 @@
22

33
declare(strict_types=1);
44

5-
65
namespace Codewithkyrian\Transformers\Models\Pretrained;
76

8-
use Codewithkyrian\Transformers\Configs\PretrainedConfig;
9-
use Codewithkyrian\Transformers\Models\ModelArchitecture;
10-
use Codewithkyrian\Transformers\Utils\AutoConfig;
11-
use Codewithkyrian\Transformers\Utils\GenerationConfig;
12-
use Codewithkyrian\Transformers\Utils\InferenceSession;
13-
14-
class GPTBigCodePretrainedModel extends PretrainedModel
15-
{
16-
}
7+
class GPTBigCodePretrainedModel extends PretrainedModel {}

0 commit comments

Comments
 (0)