diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/InferenceActionRequestTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/InferenceActionRequestTests.java index 631ac8c1d7a2a..32a1f82363080 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/InferenceActionRequestTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/InferenceActionRequestTests.java @@ -318,8 +318,10 @@ protected InferenceAction.Request mutateInstance(InferenceAction.Request instanc @Override protected InferenceAction.Request mutateInstanceForVersion(InferenceAction.Request instance, TransportVersion version) { + InferenceAction.Request mutated; + if (version.before(TransportVersions.V_8_12_0)) { - return new InferenceAction.Request( + mutated = new InferenceAction.Request( instance.getTaskType(), instance.getInferenceEntityId(), null, @@ -330,7 +332,7 @@ protected InferenceAction.Request mutateInstanceForVersion(InferenceAction.Reque false ); } else if (version.before(TransportVersions.V_8_13_0)) { - return new InferenceAction.Request( + mutated = new InferenceAction.Request( instance.getTaskType(), instance.getInferenceEntityId(), null, @@ -344,7 +346,7 @@ protected InferenceAction.Request mutateInstanceForVersion(InferenceAction.Reque && (instance.getInputType() == InputType.UNSPECIFIED || instance.getInputType() == InputType.CLASSIFICATION || instance.getInputType() == InputType.CLUSTERING)) { - return new InferenceAction.Request( + mutated = new InferenceAction.Request( instance.getTaskType(), instance.getInferenceEntityId(), null, @@ -356,7 +358,7 @@ protected InferenceAction.Request mutateInstanceForVersion(InferenceAction.Reque ); } else if (version.before(TransportVersions.V_8_13_0) && (instance.getInputType() == InputType.CLUSTERING || instance.getInputType() == InputType.CLASSIFICATION)) { - return new InferenceAction.Request( + mutated = new InferenceAction.Request( instance.getTaskType(), instance.getInferenceEntityId(), null, @@ -367,7 +369,7 @@ protected InferenceAction.Request mutateInstanceForVersion(InferenceAction.Reque false ); } else if (version.before(TransportVersions.V_8_14_0)) { - return new InferenceAction.Request( + mutated = new InferenceAction.Request( instance.getTaskType(), instance.getInferenceEntityId(), null, @@ -379,7 +381,7 @@ protected InferenceAction.Request mutateInstanceForVersion(InferenceAction.Reque ); } else if (version.before(TransportVersions.INFERENCE_CONTEXT) && version.isPatchFrom(TransportVersions.INFERENCE_CONTEXT_8_X) == false) { - return new InferenceAction.Request( + mutated = new InferenceAction.Request( instance.getTaskType(), instance.getInferenceEntityId(), instance.getQuery(), @@ -390,9 +392,18 @@ protected InferenceAction.Request mutateInstanceForVersion(InferenceAction.Reque false, InferenceContext.EMPTY_INSTANCE ); + } else { + mutated = instance; } - return instance; + // We always assume that a request has been rerouted, if it came from a node without adaptive rate limiting + if (version.before(TransportVersions.INFERENCE_REQUEST_ADAPTIVE_RATE_LIMITING)) { + mutated.setHasBeenRerouted(true); + } else { + mutated.setHasBeenRerouted(instance.hasBeenRerouted()); + } + + return mutated; } public void testWriteTo_WhenVersionIsOnAfterUnspecifiedAdded() throws IOException {