Skip to content

Commit a64b703

Browse files
author
Jicheng Lu
committed
add azure image variation
1 parent 357b185 commit a64b703

File tree

5 files changed

+163
-12
lines changed

5 files changed

+163
-12
lines changed

src/Plugins/BotSharp.Plugin.AzureOpenAI/AzureOpenAiPlugin.cs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,8 @@ public void RegisterDI(IServiceCollection services, IConfiguration config)
2828

2929
services.AddScoped<ITextCompletion, TextCompletionProvider>();
3030
services.AddScoped<IChatCompletion, ChatCompletionProvider>();
31-
services.AddScoped<IImageGeneration, ImageGenerationProvider>();
3231
services.AddScoped<ITextEmbedding, TextEmbeddingProvider>();
32+
services.AddScoped<IImageGeneration, ImageGenerationProvider>();
33+
services.AddScoped<IImageVariation, ImageVariationProvider>();
3334
}
3435
}

src/Plugins/BotSharp.Plugin.AzureOpenAI/Providers/Image/ImageGenerationProvider.cs

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -35,15 +35,15 @@ public async Task<RoleDialogModel> GetImageGeneration(Agent agent, RoleDialogMod
3535
var response = imageClient.GenerateImages(prompt, imageCount, options);
3636
var values = response.Value;
3737

38-
var images = new List<ImageGeneration>();
38+
var generatedImages = new List<ImageGeneration>();
3939
foreach (var value in values)
4040
{
4141
if (value == null) continue;
4242

43-
var image = new ImageGeneration { Description = value?.RevisedPrompt ?? string.Empty };
43+
var generatedImage = new ImageGeneration { Description = value?.RevisedPrompt ?? string.Empty };
4444
if (options.ResponseFormat == GeneratedImageFormat.Uri)
4545
{
46-
image.ImageUrl = value?.ImageUri?.AbsoluteUri ?? string.Empty;
46+
generatedImage.ImageUrl = value?.ImageUri?.AbsoluteUri ?? string.Empty;
4747
}
4848
else if (options.ResponseFormat == GeneratedImageFormat.Bytes)
4949
{
@@ -53,18 +53,18 @@ public async Task<RoleDialogModel> GetImageGeneration(Agent agent, RoleDialogMod
5353
{
5454
base64Str = Convert.ToBase64String(bytes);
5555
}
56-
image.ImageData = base64Str;
56+
generatedImage.ImageData = base64Str;
5757
}
5858

59-
images.Add(image);
59+
generatedImages.Add(generatedImage);
6060
}
6161

62-
var content = string.Join("\r\n", images.Select(x => x.Description));
62+
var content = string.Join("\r\n", generatedImages.Where(x => !string.IsNullOrWhiteSpace(x.Description)).Select(x => x.Description));
6363
var responseMessage = new RoleDialogModel(AgentRole.Assistant, content)
6464
{
6565
CurrentAgentId = agent.Id,
6666
MessageId = message?.MessageId ?? string.Empty,
67-
GeneratedImages = images
67+
GeneratedImages = generatedImages
6868
};
6969

7070
// After
Lines changed: 152 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,152 @@
1+
using OpenAI.Images;
2+
3+
namespace BotSharp.Plugin.AzureOpenAI.Providers.Image;
4+
5+
public class ImageVariationProvider : IImageVariation
6+
{
7+
protected readonly AzureOpenAiSettings _settings;
8+
protected readonly IServiceProvider _services;
9+
protected readonly ILogger<ImageVariationProvider> _logger;
10+
11+
private const int DEFAULT_IMAGE_COUNT = 1;
12+
private const int IMAGE_COUNT_LIMIT = 5;
13+
14+
protected string _model;
15+
16+
public virtual string Provider => "azure-openai";
17+
18+
public ImageVariationProvider(
19+
AzureOpenAiSettings settings,
20+
ILogger<ImageVariationProvider> logger,
21+
IServiceProvider services)
22+
{
23+
_settings = settings;
24+
_services = services;
25+
_logger = logger;
26+
}
27+
28+
public async Task<RoleDialogModel> GetImageVariation(Agent agent, RoleDialogModel message, Stream image, string imageFileName)
29+
{
30+
var client = ProviderHelper.GetClient(Provider, _model, _services);
31+
var (imageCount, options) = PrepareOptions();
32+
var imageClient = client.GetImageClient(_model);
33+
34+
var response = imageClient.GenerateImageVariations(image, imageFileName, imageCount, options);
35+
var values = response.Value;
36+
37+
var generatedImages = new List<ImageGeneration>();
38+
foreach (var value in values)
39+
{
40+
if (value == null) continue;
41+
42+
var generatedImage = new ImageGeneration { Description = value?.RevisedPrompt ?? string.Empty };
43+
if (options.ResponseFormat == GeneratedImageFormat.Uri)
44+
{
45+
generatedImage.ImageUrl = value?.ImageUri?.AbsoluteUri ?? string.Empty;
46+
}
47+
else if (options.ResponseFormat == GeneratedImageFormat.Bytes)
48+
{
49+
var base64Str = string.Empty;
50+
var bytes = value?.ImageBytes?.ToArray();
51+
if (!bytes.IsNullOrEmpty())
52+
{
53+
base64Str = Convert.ToBase64String(bytes);
54+
}
55+
generatedImage.ImageData = base64Str;
56+
}
57+
58+
generatedImages.Add(generatedImage);
59+
}
60+
61+
var content = string.Join("\r\n", generatedImages.Where(x => !string.IsNullOrWhiteSpace(x.Description)).Select(x => x.Description));
62+
var responseMessage = new RoleDialogModel(AgentRole.Assistant, content)
63+
{
64+
CurrentAgentId = agent.Id,
65+
MessageId = message?.MessageId ?? string.Empty,
66+
GeneratedImages = generatedImages
67+
};
68+
69+
return await Task.FromResult(responseMessage);
70+
}
71+
72+
public void SetModelName(string model)
73+
{
74+
_model = model;
75+
}
76+
77+
private (int, ImageVariationOptions) PrepareOptions()
78+
{
79+
var state = _services.GetRequiredService<IConversationStateService>();
80+
var size = state.GetState("image_size");
81+
var format = state.GetState("image_format");
82+
var count = GetImageCount(state.GetState("image_count", "1"));
83+
84+
var options = new ImageVariationOptions
85+
{
86+
Size = GetImageSize(size),
87+
ResponseFormat = GetImageFormat(format)
88+
};
89+
return (count, options);
90+
}
91+
92+
private GeneratedImageSize GetImageSize(string size)
93+
{
94+
var value = !string.IsNullOrEmpty(size) ? size : "1024x1024";
95+
96+
GeneratedImageSize retSize;
97+
switch (value)
98+
{
99+
case "256x256":
100+
retSize = GeneratedImageSize.W256xH256;
101+
break;
102+
case "512x512":
103+
retSize = GeneratedImageSize.W512xH512;
104+
break;
105+
case "1024x1024":
106+
retSize = GeneratedImageSize.W1024xH1024;
107+
break;
108+
case "1024x1792":
109+
retSize = GeneratedImageSize.W1024xH1792;
110+
break;
111+
case "1792x1024":
112+
retSize = GeneratedImageSize.W1792xH1024;
113+
break;
114+
default:
115+
retSize = GeneratedImageSize.W1024xH1024;
116+
break;
117+
}
118+
119+
return retSize;
120+
}
121+
122+
private GeneratedImageFormat GetImageFormat(string format)
123+
{
124+
var value = !string.IsNullOrEmpty(format) ? format : "uri";
125+
126+
GeneratedImageFormat retFormat;
127+
switch (value)
128+
{
129+
case "uri":
130+
retFormat = GeneratedImageFormat.Uri;
131+
break;
132+
case "bytes":
133+
retFormat = GeneratedImageFormat.Bytes;
134+
break;
135+
default:
136+
retFormat = GeneratedImageFormat.Uri;
137+
break;
138+
}
139+
140+
return retFormat;
141+
}
142+
143+
private int GetImageCount(string count)
144+
{
145+
if (!int.TryParse(count, out var retCount))
146+
{
147+
return DEFAULT_IMAGE_COUNT;
148+
}
149+
150+
return retCount > 0 && retCount <= IMAGE_COUNT_LIMIT ? retCount : DEFAULT_IMAGE_COUNT;
151+
}
152+
}

src/Plugins/BotSharp.Plugin.OpenAI/Providers/Image/ImageGenerationProvider.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ public async Task<RoleDialogModel> GetImageGeneration(Agent agent, RoleDialogMod
5959
generatedImages.Add(generatedImage);
6060
}
6161

62-
var content = string.Join("\r\n", generatedImages.Select(x => x.Description));
62+
var content = string.Join("\r\n", generatedImages.Where(x => !string.IsNullOrWhiteSpace(x.Description)).Select(x => x.Description));
6363
var responseMessage = new RoleDialogModel(AgentRole.Assistant, content)
6464
{
6565
CurrentAgentId = agent.Id,

src/Plugins/BotSharp.Plugin.OpenAI/Providers/Image/ImageVariationProvider.cs

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ public async Task<RoleDialogModel> GetImageVariation(Agent agent, RoleDialogMode
5858
generatedImages.Add(generatedImage);
5959
}
6060

61-
var content = string.Join("\r\n", generatedImages.Select(x => x.Description));
61+
var content = string.Join("\r\n", generatedImages.Where(x => !string.IsNullOrWhiteSpace(x.Description)).Select(x => x.Description));
6262
var responseMessage = new RoleDialogModel(AgentRole.Assistant, content)
6363
{
6464
CurrentAgentId = agent.Id,
@@ -78,8 +78,6 @@ public void SetModelName(string model)
7878
{
7979
var state = _services.GetRequiredService<IConversationStateService>();
8080
var size = state.GetState("image_size");
81-
var quality = state.GetState("image_quality");
82-
var style = state.GetState("image_style");
8381
var format = state.GetState("image_format");
8482
var count = GetImageCount(state.GetState("image_count", "1"));
8583

0 commit comments

Comments
 (0)