Skip to content

Commit 5afe92c

Browse files
PSriVarshan, namsoo2
authored and committed
Add observability to Bedrock Titan Embedding model
Signed-off-by: PSriVarshan <[email protected]>
Signed-off-by: minsoo.nam <[email protected]>
1 parent 0e50b89 commit 5afe92c

File tree

5 files changed

+119
-35
lines changed

5 files changed

+119
-35
lines changed

auto-configurations/common/spring-ai-autoconfigure-retry/src/main/java/org/springframework/ai/retry/autoconfigure/SpringAiRetryAutoConfiguration.java

Lines changed: 35 additions & 25 deletions
Original file line number | Diff line number | Diff line change
@@ -41,9 +41,12 @@
4141
import org.springframework.web.client.ResponseErrorHandler;
4242

4343
/**
44-
* {@link AutoConfiguration Auto-configuration} for AI Retry.
44+
* {@link AutoConfiguration Auto-configuration} for AI Retry. Provides beans for retry
45+
* template and response error handling. Handles transient and non-transient exceptions
46+
* based on HTTP status codes.
4547
*
4648
* @author Christian Tzolov
49+
* @author SriVarshan P
4750
*/
4851
@AutoConfiguration
4952
@ConditionalOnClass(RetryUtils.class)
@@ -63,9 +66,10 @@ public RetryTemplate retryTemplate(SpringAiRetryProperties properties) {
6366
.withListener(new RetryListener() {
6467

6568
@Override
66-
public <T extends Object, E extends Throwable> void onError(RetryContext context,
67-
RetryCallback<T, E> callback, Throwable throwable) {
68-
logger.warn("Retry error. Retry count:" + context.getRetryCount(), throwable);
69+
public <T, E extends Throwable> void onError(RetryContext context, RetryCallback<T, E> callback,
70+
Throwable throwable) {
71+
logger.warn("Retry error. Retry count: {}, Exception: {}", context.getRetryCount(),
72+
throwable.getMessage(), throwable);
6973
}
7074
})
7175
.build();
@@ -84,29 +88,35 @@ public boolean hasError(@NonNull ClientHttpResponse response) throws IOException
8488

8589
@Override
8690
public void handleError(@NonNull ClientHttpResponse response) throws IOException {
87-
if (response.getStatusCode().isError()) {
88-
String error = StreamUtils.copyToString(response.getBody(), StandardCharsets.UTF_8);
89-
String message = String.format("%s - %s", response.getStatusCode().value(), error);
90-
91-
// Explicitly configured transient codes
92-
if (properties.getOnHttpCodes().contains(response.getStatusCode().value())) {
93-
throw new TransientAiException(message);
94-
}
95-
96-
// onClientErrors - If true, do not throw a NonTransientAiException,
97-
// and do not attempt retry for 4xx client error codes, false by
98-
// default.
99-
if (!properties.isOnClientErrors() && response.getStatusCode().is4xxClientError()) {
100-
throw new NonTransientAiException(message);
101-
}
102-
103-
// Explicitly configured non-transient codes
104-
if (!CollectionUtils.isEmpty(properties.getExcludeOnHttpCodes())
105-
&& properties.getExcludeOnHttpCodes().contains(response.getStatusCode().value())) {
106-
throw new NonTransientAiException(message);
107-
}
91+
if (!response.getStatusCode().isError()) {
92+
return;
93+
}
94+
95+
String error = StreamUtils.copyToString(response.getBody(), StandardCharsets.UTF_8);
96+
if (error == null || error.isEmpty()) {
97+
error = "No response body available";
98+
}
99+
100+
String message = String.format("HTTP %s - %s", response.getStatusCode().value(), error);
101+
102+
// Explicitly configured transient codes
103+
if (properties.getOnHttpCodes().contains(response.getStatusCode().value())) {
108104
throw new TransientAiException(message);
109105
}
106+
107+
// Handle client errors (4xx)
108+
if (!properties.isOnClientErrors() && response.getStatusCode().is4xxClientError()) {
109+
throw new NonTransientAiException(message);
110+
}
111+
112+
// Explicitly configured non-transient codes
113+
if (!CollectionUtils.isEmpty(properties.getExcludeOnHttpCodes())
114+
&& properties.getExcludeOnHttpCodes().contains(response.getStatusCode().value())) {
115+
throw new NonTransientAiException(message);
116+
}
117+
118+
// Default to transient exception
119+
throw new TransientAiException(message);
110120
}
111121
};
112122
}

auto-configurations/models/spring-ai-autoconfigure-model-bedrock-ai/pom.xml

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,12 @@
7979
<optional>true</optional>
8080
</dependency>
8181

82+
<dependency>
83+
<groupId>org.springframework.boot</groupId>
84+
<artifactId>spring-boot-test</artifactId>
85+
<scope>test</scope>
86+
</dependency>
87+
8288
<dependency>
8389
<groupId>org.springframework.boot</groupId>
8490
<artifactId>spring-boot-configuration-processor</artifactId>
@@ -110,6 +116,11 @@
110116
<artifactId>mockito-core</artifactId>
111117
<scope>test</scope>
112118
</dependency>
113-
</dependencies>
119+
120+
<dependency>
121+
<groupId>io.micrometer</groupId>
122+
<artifactId>micrometer-observation</artifactId>
123+
</dependency>
124+
</dependencies>
114125

115126
</project>

auto-configurations/models/spring-ai-autoconfigure-model-bedrock-ai/src/main/java/org/springframework/ai/model/bedrock/titan/autoconfigure/BedrockTitanEmbeddingAutoConfiguration.java

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717
package org.springframework.ai.model.bedrock.titan.autoconfigure;
1818

1919
import com.fasterxml.jackson.databind.ObjectMapper;
20+
21+
import io.micrometer.observation.ObservationRegistry;
2022
import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider;
2123
import software.amazon.awssdk.regions.providers.AwsRegionProvider;
2224

@@ -26,6 +28,7 @@
2628
import org.springframework.ai.model.SpringAIModels;
2729
import org.springframework.ai.model.bedrock.autoconfigure.BedrockAwsConnectionConfiguration;
2830
import org.springframework.ai.model.bedrock.autoconfigure.BedrockAwsConnectionProperties;
31+
import org.springframework.beans.factory.ObjectProvider;
2932
import org.springframework.boot.autoconfigure.AutoConfiguration;
3033
import org.springframework.boot.autoconfigure.condition.ConditionalOnBean;
3134
import org.springframework.boot.autoconfigure.condition.ConditionalOnClass;
@@ -40,6 +43,7 @@
4043
*
4144
* @author Christian Tzolov
4245
* @author Wei Jiang
46+
* @author SriVarshan P
4347
* @since 0.8.0
4448
*/
4549
@AutoConfiguration
@@ -56,6 +60,12 @@ public class BedrockTitanEmbeddingAutoConfiguration {
5660
public TitanEmbeddingBedrockApi titanEmbeddingBedrockApi(AwsCredentialsProvider credentialsProvider,
5761
AwsRegionProvider regionProvider, BedrockTitanEmbeddingProperties properties,
5862
BedrockAwsConnectionProperties awsProperties, ObjectMapper objectMapper) {
63+
64+
// Validate required properties
65+
if (properties.getModel() == null || awsProperties.getTimeout() == null) {
66+
throw new IllegalArgumentException("Required properties for TitanEmbeddingBedrockApi are missing.");
67+
}
68+
5969
return new TitanEmbeddingBedrockApi(properties.getModel(), credentialsProvider, regionProvider.getRegion(),
6070
objectMapper, awsProperties.getTimeout());
6171
}
@@ -64,8 +74,16 @@ public TitanEmbeddingBedrockApi titanEmbeddingBedrockApi(AwsCredentialsProvider
6474
@ConditionalOnMissingBean
6575
@ConditionalOnBean(TitanEmbeddingBedrockApi.class)
6676
public BedrockTitanEmbeddingModel titanEmbeddingModel(TitanEmbeddingBedrockApi titanEmbeddingApi,
67-
BedrockTitanEmbeddingProperties properties) {
68-
return new BedrockTitanEmbeddingModel(titanEmbeddingApi).withInputType(properties.getInputType());
77+
BedrockTitanEmbeddingProperties properties, ObjectProvider<ObservationRegistry> observationRegistry) {
78+
79+
// Validate required properties
80+
if (properties.getInputType() == null) {
81+
throw new IllegalArgumentException("InputType property for BedrockTitanEmbeddingModel is missing.");
82+
}
83+
84+
return new BedrockTitanEmbeddingModel(titanEmbeddingApi,
85+
observationRegistry.getIfUnique(() -> ObservationRegistry.NOOP))
86+
.withInputType(properties.getInputType());
6987
}
7088

7189
}

models/spring-ai-bedrock/src/main/java/org/springframework/ai/bedrock/titan/BedrockTitanEmbeddingModel.java

Lines changed: 44 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,9 @@
3434
import org.springframework.ai.embedding.EmbeddingResponse;
3535
import org.springframework.util.Assert;
3636

37+
import io.micrometer.observation.ObservationRegistry;
38+
import io.micrometer.observation.Observation;
39+
3740
/**
3841
* {@link org.springframework.ai.embedding.EmbeddingModel} implementation that uses the
3942
* Bedrock Titan Embedding API. Titan Embedding supports text and image (encoded in
@@ -51,13 +54,17 @@ public class BedrockTitanEmbeddingModel extends AbstractEmbeddingModel {
5154

5255
private final TitanEmbeddingBedrockApi embeddingApi;
5356

57+
private final ObservationRegistry observationRegistry;
58+
5459
/**
5560
* Titan Embedding API input types. Could be either text or image (encoded in base64).
5661
*/
5762
private InputType inputType = InputType.TEXT;
5863

59-
public BedrockTitanEmbeddingModel(TitanEmbeddingBedrockApi titanEmbeddingBedrockApi) {
64+
public BedrockTitanEmbeddingModel(TitanEmbeddingBedrockApi titanEmbeddingBedrockApi,
65+
ObservationRegistry observationRegistry) {
6066
this.embeddingApi = titanEmbeddingBedrockApi;
67+
this.observationRegistry = observationRegistry;
6168
}
6269

6370
/**
@@ -78,17 +85,42 @@ public float[] embed(Document document) {
7885
public EmbeddingResponse call(EmbeddingRequest request) {
7986
Assert.notEmpty(request.getInstructions(), "At least one text is required!");
8087
if (request.getInstructions().size() != 1) {
81-
logger.warn(
82-
"Titan Embedding does not support batch embedding. Will make multiple API calls to embed(Document)");
88+
logger.warn("Titan Embedding does not support batch embedding. Multiple API calls will be made.");
8389
}
8490

8591
List<Embedding> embeddings = new ArrayList<>();
8692
var indexCounter = new AtomicInteger(0);
93+
8794
for (String inputContent : request.getInstructions()) {
8895
var apiRequest = createTitanEmbeddingRequest(inputContent, request.getOptions());
89-
TitanEmbeddingResponse response = this.embeddingApi.embedding(apiRequest);
90-
embeddings.add(new Embedding(response.embedding(), indexCounter.getAndIncrement()));
96+
97+
try {
98+
TitanEmbeddingResponse response = Observation
99+
.createNotStarted("bedrock.embedding", this.observationRegistry)
100+
.lowCardinalityKeyValue("model", "titan")
101+
.lowCardinalityKeyValue("input_type", this.inputType.name().toLowerCase())
102+
.highCardinalityKeyValue("input_length", String.valueOf(inputContent.length()))
103+
.observe(() -> {
104+
TitanEmbeddingResponse r = this.embeddingApi.embedding(apiRequest);
105+
Assert.notNull(r, "Embedding API returned null response");
106+
return r;
107+
});
108+
109+
if (response.embedding() == null || response.embedding().length == 0) {
110+
logger.warn("Empty embedding vector returned for input at index {}. Skipping.", indexCounter.get());
111+
continue;
112+
}
113+
114+
embeddings.add(new Embedding(response.embedding(), indexCounter.getAndIncrement()));
115+
}
116+
catch (Exception ex) {
117+
logger.error("Titan API embedding failed for input at index {}: {}", indexCounter.get(),
118+
summarizeInput(inputContent), ex);
119+
throw ex; // Optional: Continue instead of throwing if you want partial
120+
// success
121+
}
91122
}
123+
92124
return new EmbeddingResponse(embeddings);
93125
}
94126

@@ -117,6 +149,13 @@ public int dimensions() {
117149

118150
}
119151

152+
private String summarizeInput(String input) {
153+
if (this.inputType == InputType.IMAGE) {
154+
return "[image content omitted, length=" + input.length() + "]";
155+
}
156+
return input.length() > 100 ? input.substring(0, 100) + "..." : input;
157+
}
158+
120159
public enum InputType {
121160

122161
TEXT, IMAGE

models/spring-ai-bedrock/src/test/java/org/springframework/ai/bedrock/titan/BedrockTitanEmbeddingModelIT.java

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,13 +40,18 @@
4040

4141
import static org.assertj.core.api.Assertions.assertThat;
4242

43+
import io.micrometer.observation.tck.TestObservationRegistry;
44+
4345
@SpringBootTest
4446
@RequiresAwsCredentials
4547
class BedrockTitanEmbeddingModelIT {
4648

4749
@Autowired
4850
private BedrockTitanEmbeddingModel embeddingModel;
4951

52+
@Autowired
53+
TestObservationRegistry observationRegistry;
54+
5055
@Test
5156
void singleEmbedding() {
5257
assertThat(this.embeddingModel).isNotNull();
@@ -82,8 +87,9 @@ public TitanEmbeddingBedrockApi titanEmbeddingApi() {
8287
}
8388

8489
@Bean
85-
public BedrockTitanEmbeddingModel titanEmbedding(TitanEmbeddingBedrockApi titanEmbeddingApi) {
86-
return new BedrockTitanEmbeddingModel(titanEmbeddingApi);
90+
public BedrockTitanEmbeddingModel titanEmbedding(TitanEmbeddingBedrockApi titanEmbeddingApi,
91+
TestObservationRegistry observationRegistry) {
92+
return new BedrockTitanEmbeddingModel(titanEmbeddingApi, observationRegistry);
8793
}
8894

8995
}

0 commit comments

Comments
 (0)