Skip to content

Commit b1a60a5

Browse files
Add Amazon Bedrock Unified Chat Completions support
1 parent 6396f0d commit b1a60a5

File tree

1 file changed

+235
-0
lines changed

1 file changed

+235
-0
lines changed
Lines changed: 235 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,235 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference.services.amazonbedrock.client;
9+
10+
import org.elasticsearch.ElasticsearchException;
11+
import org.elasticsearch.common.util.concurrent.EsExecutors;
12+
import org.elasticsearch.test.ESTestCase;
13+
import org.elasticsearch.threadpool.ThreadPool;
14+
import org.elasticsearch.xpack.core.inference.results.StreamingUnifiedChatCompletionResults;
15+
import org.junit.Before;
16+
import org.mockito.ArgumentCaptor;
17+
import software.amazon.awssdk.services.bedrockruntime.model.*;
18+
19+
import java.util.Arrays;
20+
import java.util.concurrent.ExecutorService;
21+
import java.util.concurrent.Flow;
22+
import java.util.concurrent.atomic.AtomicInteger;
23+
24+
import static org.elasticsearch.xpack.inference.InferencePlugin.UTILITY_THREAD_POOL_NAME;
25+
import static org.hamcrest.Matchers.*;
26+
import static org.hamcrest.Matchers.isA;
27+
import static org.mockito.ArgumentMatchers.any;
28+
import static org.mockito.Mockito.*;
29+
30+
public class AmazonBedrockUnifiedStreamingChatProcessorTests extends ESTestCase {
31+
private AmazonBedrockUnifiedStreamingChatProcessor processor;
32+
33+
@Before
34+
public void setUp() throws Exception {
35+
super.setUp();
36+
ThreadPool threadPool = mock();
37+
when(threadPool.executor(UTILITY_THREAD_POOL_NAME)).thenReturn(EsExecutors.DIRECT_EXECUTOR_SERVICE);
38+
processor = new AmazonBedrockUnifiedStreamingChatProcessor(threadPool);
39+
}
40+
41+
/**
42+
* We do not issue requests on subscribe because the downstream will control the pacing.
43+
*/
44+
public void testOnSubscribeBeforeDownstreamDoesNotRequest() {
45+
var upstream = mock(Flow.Subscription.class);
46+
processor.onSubscribe(upstream);
47+
48+
verify(upstream, never()).request(anyLong());
49+
}
50+
51+
/**
52+
* If the downstream requests data before the upstream is set, when the upstream is set, we will forward the pending requests to it.
53+
*/
54+
public void testOnSubscribeAfterDownstreamRequests() {
55+
var expectedRequestCount = randomLongBetween(1, 500);
56+
Flow.Subscriber<StreamingUnifiedChatCompletionResults.Results> subscriber = mock();
57+
doAnswer(ans -> {
58+
Flow.Subscription sub = ans.getArgument(0);
59+
sub.request(expectedRequestCount);
60+
return null;
61+
}).when(subscriber).onSubscribe(any());
62+
processor.subscribe(subscriber);
63+
64+
var upstream = mock(Flow.Subscription.class);
65+
processor.onSubscribe(upstream);
66+
67+
verify(upstream, times(1)).request(anyLong());
68+
}
69+
70+
public void testCancelDuplicateSubscriptions() {
71+
processor.onSubscribe(mock());
72+
73+
var upstream = mock(Flow.Subscription.class);
74+
processor.onSubscribe(upstream);
75+
76+
verify(upstream, times(1)).cancel();
77+
verifyNoMoreInteractions(upstream);
78+
}
79+
80+
public void testMultiplePublishesCallsOnError() {
81+
processor.subscribe(mock());
82+
83+
Flow.Subscriber<StreamingUnifiedChatCompletionResults.Results> subscriber = mock();
84+
processor.subscribe(subscriber);
85+
86+
verify(subscriber, times(1)).onError(assertArg(e -> {
87+
assertThat(e, isA(IllegalStateException.class));
88+
assertThat(e.getMessage(), equalTo("Subscriber already set."));
89+
}));
90+
}
91+
//
92+
// public void testNonDeltaBlocksAreSkipped() {
93+
// var upstream = mock(Flow.Subscription.class);
94+
// processor.onSubscribe(upstream);
95+
// var counter = new AtomicInteger();
96+
// Arrays.stream(ConverseStreamOutput.EventType.values())
97+
// .filter(type -> type != ConverseStreamOutput.EventType.CONTENT_BLOCK_DELTA)
98+
// .forEach(type -> {
99+
// ConverseStreamOutput output = mock();
100+
// when(output.sdkEventType()).thenReturn(type);
101+
// processor.onNext(output);
102+
// verify(upstream, times(counter.incrementAndGet())).request(eq(1L));
103+
// });
104+
// }
105+
//
106+
// public void testDeltaBlockForwardsDownstream() {
107+
// var expectedText = "hello";
108+
//
109+
// // mock executorservice so we can make sure we handle the response on another thread
110+
// ExecutorService executorService = mock();
111+
// ThreadPool threadPool = mock();
112+
// when(threadPool.executor(UTILITY_THREAD_POOL_NAME)).thenReturn(executorService);
113+
// processor = new AmazonBedrockUnifiedStreamingChatProcessor(threadPool);
114+
// doAnswer(ans -> {
115+
// Runnable command = ans.getArgument(0);
116+
// command.run();
117+
// return null;
118+
// }).when(executorService).execute(any());
119+
//
120+
// Flow.Subscription upstream = mock();
121+
// processor.onSubscribe(upstream);
122+
// Flow.Subscriber<StreamingUnifiedChatCompletionResults.Results> downstream = mock();
123+
// processor.subscribe(downstream);
124+
//
125+
// ConverseStreamOutput output = output(expectedText);
126+
//
127+
// processor.onNext(output);
128+
//
129+
// verifyText(downstream, expectedText);
130+
// verify(executorService, times(1)).execute(any());
131+
// verify(upstream, times(0)).request(anyLong());
132+
// }
133+
134+
private ConverseStreamOutput output(String text) {
135+
ConverseStreamOutput output = mock();
136+
when(output.sdkEventType()).thenReturn(ConverseStreamOutput.EventType.CONTENT_BLOCK_DELTA);
137+
doAnswer(ans -> {
138+
ConverseStreamResponseHandler.Visitor visitor = ans.getArgument(0);
139+
ContentBlockDelta delta = ContentBlockDelta.fromText(text);
140+
ContentBlockDeltaEvent event = ContentBlockDeltaEvent.builder().delta(delta).build();
141+
visitor.visitContentBlockDelta(event);
142+
return null;
143+
}).when(output).accept(any());
144+
return output;
145+
}
146+
147+
private void verifyText(Flow.Subscriber<StreamingUnifiedChatCompletionResults.Results> downstream, String expectedText) {
148+
verify(downstream, times(1)).onNext(assertArg(results -> {
149+
assertThat(results, notNullValue());
150+
assertThat(results.chunks().size(), equalTo(1));
151+
// assertThat(results.chunks().getFirst().choices().getFirst(), equalTo(expectedText));
152+
}));
153+
}
154+
155+
public void verifyCompleteBeforeRequest() {
156+
processor.onComplete();
157+
158+
Flow.Subscriber<StreamingUnifiedChatCompletionResults.Results> downstream = mock();
159+
var sub = ArgumentCaptor.forClass(Flow.Subscription.class);
160+
processor.subscribe(downstream);
161+
verify(downstream).onSubscribe(sub.capture());
162+
163+
sub.getValue().request(1);
164+
verify(downstream, times(1)).onComplete();
165+
}
166+
167+
public void verifyCompleteAfterRequest() {
168+
169+
Flow.Subscriber<StreamingUnifiedChatCompletionResults.Results> downstream = mock();
170+
var sub = ArgumentCaptor.forClass(Flow.Subscription.class);
171+
processor.subscribe(downstream);
172+
verify(downstream).onSubscribe(sub.capture());
173+
174+
sub.getValue().request(1);
175+
processor.onComplete();
176+
verify(downstream, times(1)).onComplete();
177+
}
178+
179+
public void verifyOnErrorBeforeRequest() {
180+
var expectedError = BedrockRuntimeException.builder().message("ahhhhhh").build();
181+
processor.onError(expectedError);
182+
183+
Flow.Subscriber<StreamingUnifiedChatCompletionResults.Results> downstream = mock();
184+
var sub = ArgumentCaptor.forClass(Flow.Subscription.class);
185+
processor.subscribe(downstream);
186+
verify(downstream).onSubscribe(sub.capture());
187+
188+
sub.getValue().request(1);
189+
verify(downstream, times(1)).onError(assertArg(e -> {
190+
assertThat(e, isA(ElasticsearchException.class));
191+
assertThat(e.getCause(), is(expectedError));
192+
}));
193+
}
194+
195+
public void verifyOnErrorAfterRequest() {
196+
var expectedError = BedrockRuntimeException.builder().message("ahhhhhh").build();
197+
198+
Flow.Subscriber<StreamingUnifiedChatCompletionResults.Results> downstream = mock();
199+
var sub = ArgumentCaptor.forClass(Flow.Subscription.class);
200+
processor.subscribe(downstream);
201+
verify(downstream).onSubscribe(sub.capture());
202+
203+
sub.getValue().request(1);
204+
processor.onError(expectedError);
205+
verify(downstream, times(1)).onError(assertArg(e -> {
206+
assertThat(e, isA(ElasticsearchException.class));
207+
assertThat(e.getCause(), is(expectedError));
208+
}));
209+
}
210+
211+
public void verifyAsyncOnCompleteIsStillDeliveredSynchronously() {
212+
mockUpstream();
213+
214+
Flow.Subscriber<StreamingUnifiedChatCompletionResults.Results> downstream = mock();
215+
var sub = ArgumentCaptor.forClass(Flow.Subscription.class);
216+
processor.subscribe(downstream);
217+
verify(downstream).onSubscribe(sub.capture());
218+
219+
sub.getValue().request(1);
220+
verify(downstream, times(1)).onNext(any());
221+
processor.onComplete();
222+
verify(downstream, times(0)).onComplete();
223+
sub.getValue().request(1);
224+
verify(downstream, times(1)).onComplete();
225+
}
226+
227+
private void mockUpstream() {
228+
Flow.Subscription upstream = mock();
229+
doAnswer(ans -> {
230+
processor.onNext(output(randomIdentifier()));
231+
return null;
232+
}).when(upstream).request(anyLong());
233+
processor.onSubscribe(upstream);
234+
}
235+
}

0 commit comments

Comments
 (0)