Skip to content

Commit 0917ef2

Browse files
authored
feat: 🖼️ SDXL prompt optimizer feature (#101)
* feat: 🖼️ SDXL prompt optimizer feature * fix: 🐛 Typo on class name * fix: 🐛 Typo on user prompt
1 parent 30d8973 commit 0917ef2

File tree

3 files changed

+144
-0
lines changed

3 files changed

+144
-0
lines changed

ai/ai-endpoints/README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ Don't hesitate to use the source code and give us feedback.
1515

1616
### ☕️ Java demos ☕️
1717

18+
- [Function calling with LangChain4J](./function-calling-langchain4j)
1819
- [Simple Structured Output](./structured-output-langchain4j/)
1920
- [Natural Language Processing](./java-nlp)
2021
- [Chatbot with LangChain4j](./java-langchain4j-chatbot/): blocking mode, streaming mode and RAG mode.
Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
///usr/bin/env jbang "$0" "$@" ; exit $?
2+
//JAVA 24+
3+
//PREVIEW
4+
//DEPS dev.langchain4j:langchain4j:1.0.1 dev.langchain4j:langchain4j-mistral-ai:1.0.1-beta6
5+
6+
import dev.langchain4j.agent.tool.P;
7+
import dev.langchain4j.agent.tool.Tool;
8+
import dev.langchain4j.memory.ChatMemory;
9+
import dev.langchain4j.memory.chat.MessageWindowChatMemory;
10+
import dev.langchain4j.model.chat.ChatModel;
11+
import dev.langchain4j.model.mistralai.MistralAiChatModel;
12+
import dev.langchain4j.service.AiServices;
13+
import dev.langchain4j.service.SystemMessage;
14+
import dev.langchain4j.service.UserMessage;
15+
16+
import java.io.IOException;
17+
import java.net.URI;
18+
import java.net.http.HttpClient;
19+
import java.net.http.HttpRequest;
20+
import java.net.http.HttpResponse;
21+
import java.nio.file.Files;
22+
import java.nio.file.Path;
23+
import java.util.Scanner;
24+
25+
class ImageGenTools {
26+
27+
@Tool("""
28+
Tool to create an image with Stable Diffusion XL given a prompt and a negative prompt.
29+
""")
30+
void generateImage(@P("Prompt that explains the image") String prompt, @P("Negative prompt that explains what the image must not contains") String negativePrompt) throws IOException, InterruptedException {
31+
System.out.println("Prompt: " + prompt);
32+
System.out.println("Negative prompt: " + negativePrompt);
33+
34+
HttpRequest httpRequest = HttpRequest.newBuilder()
35+
.uri(URI.create(System.getenv("OVH_AI_ENDPOINTS_SD_URL")))
36+
.POST(HttpRequest.BodyPublishers.ofString("""
37+
{"prompt": "%s",
38+
"negative_prompt": "%s"}
39+
""".formatted(prompt, negativePrompt)))
40+
.header("accept", "application/octet-stream")
41+
.header("Content-Type", "application/json")
42+
.header("Authorization", "Bearer " + System.getenv("OVH_AI_ENDPOINTS_SDXL_ACCESS_TOKEN"))
43+
.build();
44+
45+
HttpResponse<byte[]> response = HttpClient.newHttpClient()
46+
.send(httpRequest, HttpResponse.BodyHandlers.ofByteArray());
47+
48+
System.out.println("SDXL status code: " + response.statusCode());
49+
Files.write(Path.of("generated-image.jpeg"), response.body());
50+
}
51+
}
52+
53+
/// Chatbot definition.
54+
/// The goal of the chatbot is to build a powerful prompt for Stable diffusion XML.
55+
interface ChatBot {
56+
@SystemMessage("""
57+
Your are an expert of using the Stable Diffusion XL model.
58+
The user explains in natural language what kind of image he wants.
59+
You must do the following steps:
60+
- Understand the user's request.
61+
- Generate the two kinds of prompts for stable diffusion: the prompt and the negative prompt
62+
- the prompts must be in english and detailed and optimized for the Stable Diffusion XL model.
63+
- once and only once you have this two prompts call the tool with the two prompts.
64+
If asked about to create an image, you MUST call the `generateImage` function.
65+
""")
66+
@UserMessage("Create an image with stable diffusion XL following this description: {{userMessage}}")
67+
String chat(String userMessage);
68+
}
69+
70+
void main() throws Exception {
71+
72+
// Main chatbot configuration, choose on of the available models on the AI Endpoints catalog (https://endpoints.ai.cloud.ovh.net/catalog)
73+
ChatModel chatModel = MistralAiChatModel.builder()
74+
.apiKey(System.getenv("OVH_AI_ENDPOINTS_ACCESS_TOKEN"))
75+
.baseUrl(System.getenv("OVH_AI_ENDPOINTS_MODEL_URL"))
76+
.modelName(System.getenv("OVH_AI_ENDPOINTS_MODEL_NAME"))
77+
.logRequests(false)
78+
.logResponses(false)
79+
// To have more deterministic outputs, set temperature to 0.
80+
.temperature(0.0)
81+
.build();
82+
83+
// Add memory to fine tune the SDXL prompt.
84+
ChatMemory chatMemory = MessageWindowChatMemory.withMaxMessages(10);
85+
86+
// Build the chatbot thanks to LangChain4J AI Servises mode
87+
ChatBot chatBot = AiServices.builder(ChatBot.class)
88+
.chatModel(chatModel)
89+
.tools(new ImageGenTools())
90+
.chatMemory(chatMemory)
91+
.build();
92+
93+
// Start the conversation loop (enter "exit" to quit)
94+
String userInput = "";
95+
Scanner scanner = new Scanner(System.in);
96+
while (true) {
97+
System.out.print("Enter your message: ");
98+
userInput = scanner.nextLine();
99+
if (userInput.equalsIgnoreCase("exit")) break;
100+
System.out.println("Response: " + chatBot.chat(userInput));
101+
}
102+
scanner.close();
103+
}
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
# Function Calling with LangChain4j and AI Endpoints
2+
3+
### 🧰 Prerequisites 🧰
4+
5+
- Java 24+ installed (with preview mode enabled)
6+
- AI Endpoints API token
7+
- model to use: any of the LLM instruct models
8+
- have the following environment variables created:
9+
- OVH_AI_ENDPOINTS_ACCESS_TOKEN: the API token, see [documentation](https://help.ovhcloud.com/csm/en-gb-public-cloud-ai-endpoints-getting-started?id=kb_article_view&sysparm_article=KB0065401#generating-your-first-api-access-key) to know how to generate it
10+
- OVH_AI_ENDPOINTS_MODEL_URL: URL of the model, see [AI Endpoints website](https://endpoints.ai.cloud.ovh.net/) to know how to get it.
11+
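- OVH_AI_ENDPOINTS_MODEL_NAME: model name, see [AI Endpoints website](https://endpoints.ai.cloud.ovh.net/) to know how to get it.
- OVH_AI_ENDPOINTS_SD_URL: URL of the Stable Diffusion XL endpoint called to generate the image.
- OVH_AI_ENDPOINTS_SDXL_ACCESS_TOKEN: the API token sent as bearer token to the Stable Diffusion XL endpoint.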
12+
- [JBang](https://www.jbang.dev/documentation/guide/latest/index.html) installed
13+
14+
## ⚡️ How to use the project ⚡️
15+
16+
- run `jbang ImageGeneration.java` command
17+
- explain your image with natural language
18+
- see the result in the `generated-image.jpeg`
19+
- fine tune your prompt to get better results
20+
- enter `exit` to quit the application
21+
22+
## 🗺️ Architecture 🗺️
23+
24+
```mermaid
25+
graph RL
26+
subgraph User app
27+
A[Chatbot]
28+
D[ImageGenTools]
29+
end
30+
subgraph AI Endpoints
31+
E[Stable Diffusion XL]
32+
B[LLM Model]
33+
end
34+
A[Chatbot] -->| 1-Ask to create an image in natural language | B[LLM Model]
35+
B -->| 2-Create a SDXL prompt + tool name | A
36+
A -->| 3-Call generateImage | D[ImageGenTools]
37+
D -->| 4-Call SDXL with generated prompts | E[Stable Diffusion XL]
38+
E -->| 5-Generated image | D
39+
B -->| 6-Final response | A
40+
```

0 commit comments

Comments
(0)