Skip to content

Commit bf080f7

Browse files
feat: Add support for eos and last_token pooling in FeatureExtractionPipeline
1 parent 024ed11 commit bf080f7

File tree

1 file changed

+10
-3
lines changed

1 file changed

+10
-3
lines changed

src/Pipelines/FeatureExtractionPipeline.php

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22

33
declare(strict_types=1);
44

5-
65
namespace Codewithkyrian\Transformers\Pipelines;
76

87
use Codewithkyrian\Transformers\Tensor\Tensor;
@@ -80,15 +79,23 @@ public function __invoke(array|string $inputs, ...$args): array
8079
case 'none':
8180
// No pooling, return the full tensor
8281
break;
82+
8383
case 'mean':
8484
$result = $result->meanPooling($modelInputs["attention_mask"]);
8585
break;
86+
87+
case 'first_token':
8688
case 'cls':
8789
$result = $result->slice(null, 0);
8890
break;
8991

92+
case 'last_token':
93+
case 'eos':
94+
$result = $result->slice(null, -1);
95+
break;
96+
9097
default:
91-
throw new \Error("Pooling method not supported. Please use 'mean', 'cls', or 'none'.");
98+
throw new \Error("Pooling method not supported. Please use 'mean', 'cls', 'first_token', 'last_token', or 'none'.");
9299
}
93100

94101
if ($normalize) {
@@ -97,4 +104,4 @@ public function __invoke(array|string $inputs, ...$args): array
97104

98105
return $result->toArray();
99106
}
100-
}
107+
}

0 commit comments

Comments
 (0)