Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,10 @@
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.sap.ai.sdk.app.services.OpenAiService;
import com.sap.ai.sdk.foundationmodels.openai.model.OpenAiChatCompletionOutput;
import com.sap.cloud.sdk.cloudplatform.thread.ThreadContextExecutors;
import java.io.IOException;
import java.util.Arrays;
import javax.annotation.Nonnull;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired;
Expand Down Expand Up @@ -36,7 +40,7 @@ ResponseEntity<String> chatCompletion(
@RequestHeader(value = "accept", required = false) final String accept)
throws JsonProcessingException {
final var response = service.chatCompletion("Who is the prettiest");
if (accept.equals("application/json")) {
if ("application/json".equals(accept)) {
return ResponseEntity.ok()
.contentType(MediaType.APPLICATION_JSON)
.body(mapper.writeValueAsString(response));
Expand All @@ -54,7 +58,26 @@ ResponseEntity<String> chatCompletion(
@Nonnull
ResponseEntity<ResponseBodyEmitter> streamChatCompletionDeltas() {
final var message = "Can you give me the first 100 numbers of the Fibonacci sequence?";
return service.streamChatCompletionDeltas(message);
final var stream = service.streamChatCompletionDeltas(message);
final var emitter = new ResponseBodyEmitter();
final Runnable consumeStream =
() -> {
final var totalOutput = new OpenAiChatCompletionOutput();
// try-with-resources ensures the stream is closed
try (stream) {
stream
.peek(totalOutput::addDelta)
.forEach(delta -> send(emitter, delta.getDeltaContent()));
} finally {
send(emitter, "\n\n-----Total Output-----\n\n" + objectToJson(totalOutput));
emitter.complete();
}
};

ThreadContextExecutors.getExecutor().execute(consumeStream);

// TEXT_EVENT_STREAM allows the browser to display the content as it is streamed
return ResponseEntity.ok().contentType(MediaType.TEXT_EVENT_STREAM).body(emitter);
}

/**
Expand All @@ -67,7 +90,51 @@ ResponseEntity<ResponseBodyEmitter> streamChatCompletionDeltas() {
@Nonnull
ResponseEntity<ResponseBodyEmitter> streamChatCompletion() {
final var message = "Can you give me the first 100 numbers of the Fibonacci sequence?";
return service.streamChatCompletion(message);
final var stream = service.streamChatCompletion(message);
final var emitter = new ResponseBodyEmitter();

final Runnable consumeStream =
() -> {
try (stream) {
stream.forEach(deltaMessage -> send(emitter, deltaMessage));
} finally {
emitter.complete();
}
};

ThreadContextExecutors.getExecutor().execute(consumeStream);

// TEXT_EVENT_STREAM allows the browser to display the content as it is streamed
return ResponseEntity.ok().contentType(MediaType.TEXT_EVENT_STREAM).body(emitter);
}

/**
* Send a chunk to the emitter
*
* @param emitter The emitter to send the chunk to
* @param chunk The chunk to send
*/
public static void send(@Nonnull final ResponseBodyEmitter emitter, @Nonnull final String chunk) {
try {
emitter.send(chunk);
} catch (final IOException e) {
log.error(Arrays.toString(e.getStackTrace()));
emitter.completeWithError(e);
}
}

/**
* Convert an object to JSON
*
* @param obj The object to convert
* @return The JSON representation of the object
*/
private static String objectToJson(@Nonnull final Object obj) {
try {
return new ObjectMapper().writerWithDefaultPrettyPrinter().writeValueAsString(obj);
} catch (final JsonProcessingException ignored) {
return "Could not parse object to JSON";
}
}

/**
Expand All @@ -83,7 +150,7 @@ ResponseEntity<String> chatCompletionImage(
final var response =
service.chatCompletionImage(
"https://upload.wikimedia.org/wikipedia/commons/thumb/5/59/SAP_2011_logo.svg/440px-SAP_2011_logo.svg.png");
if (accept.equals("application/json")) {
if ("application/json".equals(accept)) {
return ResponseEntity.ok()
.contentType(MediaType.APPLICATION_JSON)
.body(mapper.writeValueAsString(response));
Expand All @@ -103,7 +170,7 @@ ResponseEntity<String> chatCompletionTools(
throws JsonProcessingException {
final var response =
service.chatCompletionTools("Calculate the Fibonacci number for given sequence index.");
if (accept.equals("application/json")) {
if ("application/json".equals(accept)) {
return ResponseEntity.ok()
.contentType(MediaType.APPLICATION_JSON)
.body(mapper.writeValueAsString(response));
Expand Down Expand Up @@ -139,7 +206,7 @@ ResponseEntity<String> chatCompletionWithResource(
throws JsonProcessingException {
final var response =
service.chatCompletionWithResource(resourceGroup, "Where is the nearest coffee shop?");
if (accept.equals("application/json")) {
if ("application/json".equals(accept)) {
return ResponseEntity.ok()
.contentType(MediaType.APPLICATION_JSON)
.body(mapper.writeValueAsString(response));
Expand Down
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
package com.sap.ai.sdk.app.controllers;

import static com.sap.ai.sdk.app.controllers.OpenAiController.send;

import com.fasterxml.jackson.annotation.JsonAutoDetect.Visibility;
import com.fasterxml.jackson.annotation.PropertyAccessor;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.sap.ai.sdk.app.services.OrchestrationService;
import com.sap.ai.sdk.orchestration.AzureFilterThreshold;
import com.sap.ai.sdk.orchestration.model.DPIEntities;
import com.sap.cloud.sdk.cloudplatform.thread.ThreadContextExecutors;
import javax.annotation.Nonnull;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired;
Expand Down Expand Up @@ -40,7 +43,7 @@ ResponseEntity<String> completion(
@RequestHeader(value = "accept", required = false) final String accept)
throws JsonProcessingException {
final var response = service.completion("HelloWorld!");
if (accept.equals("application/json")) {
if ("application/json".equals(accept)) {
return ResponseEntity.ok()
.contentType(MediaType.APPLICATION_JSON)
.body(mapper.writeValueAsString(response));
Expand All @@ -56,7 +59,25 @@ ResponseEntity<String> completion(
@GetMapping("/streamChatCompletion")
@Nonnull
ResponseEntity<ResponseBodyEmitter> streamChatCompletion() {
return service.streamChatCompletion("developing a software project");
final var stream = service.streamChatCompletion("developing a software project");
final var emitter = new ResponseBodyEmitter();
final Runnable consumeStream =
() -> {
try (stream) {
stream.forEach(
deltaMessage -> {
log.info("Service: {}", deltaMessage);
send(emitter, deltaMessage);
});
} finally {
emitter.complete();
}
};

ThreadContextExecutors.getExecutor().execute(consumeStream);

// TEXT_EVENT_STREAM allows the browser to display the content as it is streamed
return ResponseEntity.ok().contentType(MediaType.TEXT_EVENT_STREAM).body(emitter);
}

/**
Expand All @@ -72,7 +93,7 @@ ResponseEntity<Object> template(
@RequestHeader(value = "accept", required = false) final String accept)
throws JsonProcessingException {
final var response = service.template("German");
if (accept.equals("application/json")) {
if ("application/json".equals(accept)) {
return ResponseEntity.ok()
.contentType(MediaType.APPLICATION_JSON)
.body(mapper.writeValueAsString(response));
Expand All @@ -91,7 +112,7 @@ ResponseEntity<String> messagesHistory(
@RequestHeader(value = "accept", required = false) final String accept)
throws JsonProcessingException {
final var response = service.messagesHistory("What is the capital of France?");
if (accept.equals("application/json")) {
if ("application/json".equals(accept)) {
return ResponseEntity.ok()
.contentType(MediaType.APPLICATION_JSON)
.body(mapper.writeValueAsString(response));
Expand All @@ -118,7 +139,7 @@ ResponseEntity<String> filter(
@Nonnull @PathVariable("policy") final AzureFilterThreshold policy)
throws JsonProcessingException {
final var response = service.filter(policy, "the downtown area");
if (accept.equals("application/json")) {
if ("application/json".equals(accept)) {
return ResponseEntity.ok()
.contentType(MediaType.APPLICATION_JSON)
.body(mapper.writeValueAsString(response));
Expand All @@ -142,7 +163,7 @@ ResponseEntity<String> maskingAnonymization(
@RequestHeader(value = "accept", required = false) final String accept)
throws JsonProcessingException {
final var response = service.maskingAnonymization(DPIEntities.PERSON);
if (accept.equals("application/json")) {
if ("application/json".equals(accept)) {
return ResponseEntity.ok()
.contentType(MediaType.APPLICATION_JSON)
.body(mapper.writeValueAsString(response));
Expand All @@ -162,7 +183,7 @@ public ResponseEntity<String> completionWithResourceGroup(
@PathVariable("resourceGroup") @Nonnull final String resourceGroup)
throws JsonProcessingException {
final var response = service.completionWithResourceGroup(resourceGroup, "Hello world!");
if (accept.equals("application/json")) {
if ("application/json".equals(accept)) {
return ResponseEntity.ok()
.contentType(MediaType.APPLICATION_JSON)
.body(mapper.writeValueAsString(response));
Expand All @@ -185,7 +206,7 @@ ResponseEntity<String> maskingPseudonymization(
@RequestHeader(value = "accept", required = false) final String accept)
throws JsonProcessingException {
final var response = service.maskingPseudonymization(DPIEntities.PERSON);
if (accept.equals("application/json")) {
if ("application/json".equals(accept)) {
return ResponseEntity.ok()
.contentType(MediaType.APPLICATION_JSON)
.body(mapper.writeValueAsString(response));
Expand All @@ -206,7 +227,7 @@ ResponseEntity<String> grounding(
@RequestHeader(value = "accept", required = false) final String accept)
throws JsonProcessingException {
final var response = service.grounding("What does Joule do?");
if (accept.equals("application/json")) {
if ("application/json".equals(accept)) {
return ResponseEntity.ok()
.contentType(MediaType.APPLICATION_JSON)
.body(mapper.writeValueAsString(response));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,28 +5,22 @@
import static com.sap.ai.sdk.foundationmodels.openai.OpenAiModel.TEXT_EMBEDDING_ADA_002;
import static com.sap.ai.sdk.foundationmodels.openai.model.OpenAiChatCompletionTool.ToolType.FUNCTION;

import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.sap.ai.sdk.core.AiCoreService;
import com.sap.ai.sdk.foundationmodels.openai.OpenAiClient;
import com.sap.ai.sdk.foundationmodels.openai.model.OpenAiChatCompletionDelta;
import com.sap.ai.sdk.foundationmodels.openai.model.OpenAiChatCompletionFunction;
import com.sap.ai.sdk.foundationmodels.openai.model.OpenAiChatCompletionOutput;
import com.sap.ai.sdk.foundationmodels.openai.model.OpenAiChatCompletionParameters;
import com.sap.ai.sdk.foundationmodels.openai.model.OpenAiChatCompletionTool;
import com.sap.ai.sdk.foundationmodels.openai.model.OpenAiChatMessage;
import com.sap.ai.sdk.foundationmodels.openai.model.OpenAiEmbeddingOutput;
import com.sap.ai.sdk.foundationmodels.openai.model.OpenAiEmbeddingParameters;
import com.sap.cloud.sdk.cloudplatform.thread.ThreadContextExecutors;
import java.io.IOException;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.stream.Stream;
import javax.annotation.Nonnull;
import lombok.extern.slf4j.Slf4j;
import org.springframework.http.MediaType;
import org.springframework.http.ResponseEntity;
import org.springframework.stereotype.Service;
import org.springframework.web.servlet.mvc.method.annotation.ResponseBodyEmitter;

/** Service class for OpenAI service */
@Service
Expand All @@ -50,34 +44,13 @@ public OpenAiChatCompletionOutput chatCompletion(@Nonnull final String prompt) {
* @return the emitter that streams the assistant message response
*/
@Nonnull
public ResponseEntity<ResponseBodyEmitter> streamChatCompletionDeltas(
public Stream<OpenAiChatCompletionDelta> streamChatCompletionDeltas(
@Nonnull final String message) {
final var request =
new OpenAiChatCompletionParameters()
.addMessages(new OpenAiChatMessage.OpenAiChatUserMessage().addText(message));

final var stream = OpenAiClient.forModel(GPT_35_TURBO).streamChatCompletionDeltas(request);

final var emitter = new ResponseBodyEmitter();

final Runnable consumeStream =
() -> {
final var totalOutput = new OpenAiChatCompletionOutput();
// try-with-resources ensures the stream is closed
try (stream) {
stream
.peek(totalOutput::addDelta)
.forEach(delta -> send(emitter, delta.getDeltaContent()));
} finally {
send(emitter, "\n\n-----Total Output-----\n\n" + objectToJson(totalOutput));
emitter.complete();
}
};

ThreadContextExecutors.getExecutor().execute(consumeStream);

// TEXT_EVENT_STREAM allows the browser to display the content as it is streamed
return ResponseEntity.ok().contentType(MediaType.TEXT_EVENT_STREAM).body(emitter);
return OpenAiClient.forModel(GPT_35_TURBO).streamChatCompletionDeltas(request);
}

/**
Expand All @@ -86,50 +59,10 @@ public ResponseEntity<ResponseBodyEmitter> streamChatCompletionDeltas(
* @return the emitter that streams the assistant message response
*/
@Nonnull
public ResponseEntity<ResponseBodyEmitter> streamChatCompletion(@Nonnull final String message) {
final var stream =
OpenAiClient.forModel(GPT_35_TURBO)
.withSystemPrompt("Be a good, honest AI and answer the following question:")
.streamChatCompletion(message);

final var emitter = new ResponseBodyEmitter();

final Runnable consumeStream =
() -> {
try (stream) {
stream.forEach(deltaMessage -> send(emitter, deltaMessage));
} finally {
emitter.complete();
}
};

ThreadContextExecutors.getExecutor().execute(consumeStream);

// TEXT_EVENT_STREAM allows the browser to display the content as it is streamed
return ResponseEntity.ok().contentType(MediaType.TEXT_EVENT_STREAM).body(emitter);
}

private static String objectToJson(@Nonnull final Object obj) {
try {
return new ObjectMapper().writerWithDefaultPrettyPrinter().writeValueAsString(obj);
} catch (final JsonProcessingException ignored) {
return "Could not parse object to JSON";
}
}

/**
* Send a chunk to the emitter
*
* @param emitter The emitter to send the chunk to
* @param chunk The chunk to send
*/
public static void send(@Nonnull final ResponseBodyEmitter emitter, @Nonnull final String chunk) {
try {
emitter.send(chunk);
} catch (final IOException e) {
log.error(Arrays.toString(e.getStackTrace()));
emitter.completeWithError(e);
}
public Stream<String> streamChatCompletion(@Nonnull final String message) {
return OpenAiClient.forModel(GPT_35_TURBO)
.withSystemPrompt("Be a good, honest AI and answer the following question:")
.streamChatCompletion(message);
}

/**
Expand Down
Loading
Loading