Skip to content

Commit a2d26a5

Browse files
authored
Add eos/last_token pooling (#1335)
1 parent a5847c9 commit a2d26a5

File tree

1 file changed

+19
-9
lines changed

1 file changed

+19
-9
lines changed

src/pipelines.js

Lines changed: 19 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1229,7 +1229,7 @@ export class ZeroShotClassificationPipeline extends (/** @type {new (options: Te
12291229

12301230
/**
12311231
* @typedef {Object} FeatureExtractionPipelineOptions Parameters specific to feature extraction pipelines.
1232-
* @property {'none'|'mean'|'cls'} [pooling="none"] The pooling method to use.
1232+
* @property {'none'|'mean'|'cls'|'first_token'|'eos'|'last_token'} [pooling="none"] The pooling method to use.
12331233
* @property {boolean} [normalize=false] Whether or not to normalize the embeddings in the last dimension.
12341234
* @property {boolean} [quantize=false] Whether or not to quantize the embeddings.
12351235
* @property {'binary'|'ubinary'} [precision='binary'] The precision to use for quantization.
@@ -1322,14 +1322,24 @@ export class FeatureExtractionPipeline extends (/** @type {new (options: TextPip
13221322

13231323
/** @type {Tensor} */
13241324
let result = outputs.last_hidden_state ?? outputs.logits ?? outputs.token_embeddings;
1325-
if (pooling === 'none') {
1326-
// Skip pooling
1327-
} else if (pooling === 'mean') {
1328-
result = mean_pooling(result, model_inputs.attention_mask);
1329-
} else if (pooling === 'cls') {
1330-
result = result.slice(null, 0);
1331-
} else {
1332-
throw Error(`Pooling method '${pooling}' not supported.`);
1325+
1326+
switch (pooling) {
1327+
case 'none':
1328+
// Skip pooling
1329+
break;
1330+
case 'mean':
1331+
result = mean_pooling(result, model_inputs.attention_mask);
1332+
break;
1333+
case 'first_token':
1334+
case 'cls':
1335+
result = result.slice(null, 0);
1336+
break;
1337+
case 'last_token':
1338+
case 'eos':
1339+
result = result.slice(null, -1);
1340+
break;
1341+
default:
1342+
throw Error(`Pooling method '${pooling}' not supported.`);
13331343
}
13341344

13351345
if (normalize) {

0 commit comments

Comments
 (0)