Skip to content

Commit d919cdf

Browse files
authored
Mean pooling for last hidden state (#4004)
* Add mean_pooling parameter to reidentification.py Average the embeddings of all tokens for last_hidden_state * Update README.md * Remove trailing whitespace
1 parent 3b53c42 commit d919cdf

File tree

2 files changed

+10
-2
lines changed

2 files changed

+10
-2
lines changed

tools/accuracy_checker/accuracy_checker/adapters/README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ AccuracyChecker supports following set of adapters:
4343
* `joining_method` - method used to join embeddings (optional, supported methods are `sum` and `concatenation`, default - `sum`).
4444
* `target_out` - target output layer name (Optional, if not provided first in the model will be used).
4545
* `keep_shape` - allow keeping initial shape for predicted embedding (Optional, default `False`, it means that model output will be flattenized).
46+
* `mean_pooling` - average the embeddings of all tokens from last_hidden_state (Optional, default `False`)
4647
* `yolo_v2` - converting output of YOLO v2 family models to `DetectionPrediction` representation.
4748
* `classes` - number of detection classes (default 20).
4849
* `anchors` - anchor values provided as comma-separated list or one of precomputed:

tools/accuracy_checker/accuracy_checker/adapters/reidentification.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
"""
2-
Copyright (c) 2018-2024 Intel Corporation
2+
Copyright (c) 2018-2025 Intel Corporation
33
44
Licensed under the Apache License, Version 2.0 (the "License");
55
you may not use this file except in compliance with the License.
@@ -41,7 +41,9 @@ def parameters(cls):
4141
choices=['sum', 'concatenation']
4242
),
4343
'target_out': StringField(optional=True, description='Target output layer name'),
44-
'keep_shape': BoolField(optional=True, default=False, description='keep output embedding shape')
44+
'keep_shape': BoolField(optional=True, default=False, description='keep output embedding shape'),
45+
'mean_pooling': BoolField(optional=True, default=False,
46+
description='Average the embeddings of all tokens for last_hidden_state')
4547
})
4648

4749
return parameters
@@ -54,6 +56,7 @@ def configure(self):
5456
self.joining_method = self.get_value_from_config('joining_method')
5557
self.target_out = self.get_value_from_config('target_out')
5658
self.keep_shape = self.get_value_from_config('keep_shape')
59+
self.mean_pooling = self.get_value_from_config('mean_pooling')
5760

5861
def process(self, raw, identifiers, frame_meta):
5962
"""
@@ -67,6 +70,10 @@ def process(self, raw, identifiers, frame_meta):
6770
raw_prediction = self._extract_predictions(raw, frame_meta)
6871
prediction = raw_prediction[self.output_blob]
6972

73+
if self.mean_pooling:
74+
# Shape: (1, 128, 768) -> (1, 768)
75+
prediction = np.mean(prediction, axis=1)
76+
7077
if self.grn_workaround:
7178
# workaround: GRN layer
7279
prediction = self._grn_layer(prediction)

0 commit comments

Comments
 (0)