Skip to content

Commit fcf7387

Browse files
committed
Fix testIndexingPressure
1 parent 96f4037 commit fcf7387

File tree

1 file changed

+14
-9
lines changed

1 file changed

+14
-9
lines changed

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,8 @@
102102
import static org.mockito.ArgumentMatchers.anyInt;
103103
import static org.mockito.ArgumentMatchers.anyLong;
104104
import static org.mockito.ArgumentMatchers.anyString;
105+
import static org.mockito.ArgumentMatchers.eq;
106+
import static org.mockito.ArgumentMatchers.longThat;
105107
import static org.mockito.Mockito.any;
106108
import static org.mockito.Mockito.doAnswer;
107109
import static org.mockito.Mockito.mock;
@@ -498,9 +500,6 @@ public void testIndexingPressure() throws Exception {
498500
final InstrumentedIndexingPressure indexingPressure = new InstrumentedIndexingPressure(Settings.EMPTY);
499501
final StaticModel sparseModel = StaticModel.createRandomInstance(TaskType.SPARSE_EMBEDDING);
500502
final StaticModel denseModel = StaticModel.createRandomInstance(TaskType.TEXT_EMBEDDING);
501-
final int denseModelEmbeddingBytes = denseModel.getServiceSettings()
502-
.elementType()
503-
.getNumBytes(denseModel.getServiceSettings().dimensions());
504503
final Function<XContentBuilder, Long> bytesUsed = b -> BytesReference.bytes(b).ramBytesUsed();
505504
final ShardBulkInferenceActionFilter filter = createFilter(
506505
threadPool,
@@ -554,14 +553,20 @@ public void testIndexingPressure() throws Exception {
554553

555554
IndexingPressure.Coordinating coordinatingIndexingPressure = indexingPressure.getCoordinating();
556555
assertThat(coordinatingIndexingPressure, notNullValue());
557-
verify(coordinatingIndexingPressure).increment(1, 128 + bytesUsed.apply(doc0Source));
558-
verify(coordinatingIndexingPressure).increment(1, denseModelEmbeddingBytes + bytesUsed.apply(doc1Source));
559-
verify(coordinatingIndexingPressure).increment(1, 128 + denseModelEmbeddingBytes + bytesUsed.apply(doc2Source));
560-
verify(coordinatingIndexingPressure).increment(1, denseModelEmbeddingBytes * 2L + bytesUsed.apply(doc3Source));
561-
verify(coordinatingIndexingPressure).increment(1, 128 + bytesUsed.apply(doc0UpdateSource));
556+
verify(coordinatingIndexingPressure).increment(1, bytesUsed.apply(doc0Source));
557+
verify(coordinatingIndexingPressure).increment(1, bytesUsed.apply(doc1Source));
558+
verify(coordinatingIndexingPressure).increment(1, bytesUsed.apply(doc2Source));
559+
verify(coordinatingIndexingPressure).increment(1, bytesUsed.apply(doc3Source));
560+
verify(coordinatingIndexingPressure).increment(1, bytesUsed.apply(doc4Source));
561+
verify(coordinatingIndexingPressure).increment(1, bytesUsed.apply(doc0UpdateSource));
562+
if (useLegacyFormat == false) {
563+
verify(coordinatingIndexingPressure).increment(1, bytesUsed.apply(doc1UpdateSource));
564+
}
565+
566+
verify(coordinatingIndexingPressure, times(useLegacyFormat ? 6 : 7)).increment(eq(0), longThat(l -> l > 0));
562567

563568
// Verify that the only times that increment is called are the times verified above
564-
verify(coordinatingIndexingPressure, times(5)).increment(anyInt(), anyLong());
569+
verify(coordinatingIndexingPressure, times(useLegacyFormat ? 12 : 14)).increment(anyInt(), anyLong());
565570

566571
// Verify that the coordinating indexing pressure is maintained through downstream action filters
567572
verify(coordinatingIndexingPressure, never()).close();

0 commit comments

Comments
 (0)