Skip to content

Commit 5637a41

Browse files
authored
Merge pull request #555 from iceljc/features/add-image-mask-edit
add image mask edit
2 parents 37fbd6d + b1fb9d6 commit 5637a41

File tree

27 files changed

+578
-140
lines changed

27 files changed

+578
-140
lines changed

src/Infrastructure/BotSharp.Abstraction/Conversations/Models/IncomingMessageModel.cs

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,4 @@ public class IncomingMessageModel : MessageConfig
1212
/// Postback message
1313
/// </summary>
1414
public PostbackMessageModel? Postback { get; set; }
15-
16-
public List<BotSharpFile> Files { get; set; } = new List<BotSharpFile>();
1715
}

src/Infrastructure/BotSharp.Abstraction/Files/IBotSharpFileService.cs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,9 @@ Task<IEnumerable<MessageFileModel>> GetChatFiles(string conversationId, string s
4545

4646
#region Image
4747
Task<RoleDialogModel> GenerateImage(string? provider, string? model, string text);
48-
Task<RoleDialogModel> VarifyImage(string? provider, string? model, BotSharpFile file);
48+
Task<RoleDialogModel> VaryImage(string? provider, string? model, BotSharpFile image);
49+
Task<RoleDialogModel> EditImage(string? provider, string? model, string text, BotSharpFile image);
50+
Task<RoleDialogModel> EditImage(string? provider, string? model, string text, BotSharpFile image, BotSharpFile mask);
4951
#endregion
5052

5153
#region Pdf
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
namespace BotSharp.Abstraction.Files.Models;
2+
3+
public class InputMessageFiles
4+
{
5+
public List<BotSharpFile> Files { get; set; } = new List<BotSharpFile>();
6+
public BotSharpFile? Mask { get; set; }
7+
}

src/Infrastructure/BotSharp.Abstraction/Files/Models/LlmFileContext.cs

Lines changed: 0 additions & 16 deletions
This file was deleted.

src/Infrastructure/BotSharp.Abstraction/MLTasks/IImageCompletion.cs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,4 +18,8 @@ public interface IImageCompletion
1818
Task<RoleDialogModel> GetImageGeneration(Agent agent, RoleDialogModel message);
1919

2020
Task<RoleDialogModel> GetImageVariation(Agent agent, RoleDialogModel message, Stream image, string imageFileName);
21+
22+
Task<RoleDialogModel> GetImageEdits(Agent agent, RoleDialogModel message, Stream image, string imageFileName);
23+
24+
Task<RoleDialogModel> GetImageEdits(Agent agent, RoleDialogModel message, Stream image, string imageFileName, Stream mask, string maskFileName);
2125
}

src/Infrastructure/BotSharp.Abstraction/Models/MessageConfig.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
namespace BotSharp.Abstraction.Models;
22

3-
public class MessageConfig
3+
public class MessageConfig : InputMessageFiles
44
{
55
/// <summary>
66
/// Completion Provider

src/Infrastructure/BotSharp.Core/Files/Services/BotSharpFileService.Image.cs

Lines changed: 57 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,25 +14,77 @@ public async Task<RoleDialogModel> GenerateImage(string? provider, string? model
1414
return message;
1515
}
1616

17-
public async Task<RoleDialogModel> VarifyImage(string? provider, string? model, BotSharpFile file)
17+
public async Task<RoleDialogModel> VaryImage(string? provider, string? model, BotSharpFile image)
1818
{
19-
if (string.IsNullOrWhiteSpace(file?.FileUrl) && string.IsNullOrWhiteSpace(file?.FileData))
19+
if (string.IsNullOrWhiteSpace(image?.FileUrl) && string.IsNullOrWhiteSpace(image?.FileData))
2020
{
21-
throw new ArgumentException($"Please fill in at least file url or file data!");
21+
throw new ArgumentException($"Cannot find image url or data!");
2222
}
2323

2424
var completion = CompletionProvider.GetImageCompletion(_services, provider: provider ?? "openai", model: model ?? "dall-e-2");
25-
var bytes = await DownloadFile(file);
25+
var bytes = await DownloadFile(image);
2626
using var stream = new MemoryStream();
2727
stream.Write(bytes, 0, bytes.Length);
2828
stream.Position = 0;
2929

3030
var message = await completion.GetImageVariation(new Agent()
3131
{
3232
Id = Guid.Empty.ToString()
33-
}, new RoleDialogModel(AgentRole.User, string.Empty), stream, file.FileName ?? string.Empty);
33+
}, new RoleDialogModel(AgentRole.User, string.Empty), stream, image.FileName ?? string.Empty);
34+
3435
stream.Close();
36+
return message;
37+
}
38+
39+
public async Task<RoleDialogModel> EditImage(string? provider, string? model, string text, BotSharpFile image)
40+
{
41+
if (string.IsNullOrWhiteSpace(image?.FileUrl) && string.IsNullOrWhiteSpace(image?.FileData))
42+
{
43+
throw new ArgumentException($"Cannot find image url or data!");
44+
}
45+
46+
var completion = CompletionProvider.GetImageCompletion(_services, provider: provider ?? "openai", model: model ?? "dall-e-2");
47+
var bytes = await DownloadFile(image);
48+
using var stream = new MemoryStream();
49+
stream.Write(bytes, 0, bytes.Length);
50+
stream.Position = 0;
3551

52+
var message = await completion.GetImageEdits(new Agent()
53+
{
54+
Id = Guid.Empty.ToString()
55+
}, new RoleDialogModel(AgentRole.User, text), stream, image.FileName ?? string.Empty);
56+
57+
stream.Close();
58+
return message;
59+
}
60+
61+
public async Task<RoleDialogModel> EditImage(string? provider, string? model, string text, BotSharpFile image, BotSharpFile mask)
62+
{
63+
if ((string.IsNullOrWhiteSpace(image?.FileUrl) && string.IsNullOrWhiteSpace(image?.FileData)) ||
64+
(string.IsNullOrWhiteSpace(mask?.FileUrl) && string.IsNullOrWhiteSpace(mask?.FileData)))
65+
{
66+
throw new ArgumentException($"Cannot find image/mask url or data");
67+
}
68+
69+
var completion = CompletionProvider.GetImageCompletion(_services, provider: provider ?? "openai", model: model ?? "dall-e-2");
70+
var imageBytes = await DownloadFile(image);
71+
var maskBytes = await DownloadFile(mask);
72+
73+
using var imageStream = new MemoryStream();
74+
imageStream.Write(imageBytes, 0, imageBytes.Length);
75+
imageStream.Position = 0;
76+
77+
using var maskStream = new MemoryStream();
78+
maskStream.Write(maskBytes, 0, maskBytes.Length);
79+
maskStream.Position = 0;
80+
81+
var message = await completion.GetImageEdits(new Agent()
82+
{
83+
Id = Guid.Empty.ToString()
84+
}, new RoleDialogModel(AgentRole.User, text), imageStream, image.FileName ?? string.Empty, maskStream, mask.FileName ?? string.Empty);
85+
86+
imageStream.Close();
87+
maskStream.Close();
3688
return message;
3789
}
3890

src/Infrastructure/BotSharp.OpenAPI/Controllers/InstructModeController.cs

Lines changed: 72 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ public async Task<string> TextCompletion([FromBody] IncomingMessageModel input)
5656
return await textCompletion.GetCompletion(input.Text, Guid.Empty.ToString(), Guid.NewGuid().ToString());
5757
}
5858

59+
#region Chat
5960
[HttpPost("/instruct/chat-completion")]
6061
public async Task<string> ChatCompletion([FromBody] IncomingMessageModel input)
6162
{
@@ -75,7 +76,9 @@ public async Task<string> ChatCompletion([FromBody] IncomingMessageModel input)
7576
});
7677
return message.Content;
7778
}
79+
#endregion
7880

81+
#region Read image
7982
[HttpPost("/instruct/multi-modal")]
8083
public async Task<string> MultiModalCompletion([FromBody] IncomingMessageModel input)
8184
{
@@ -105,7 +108,9 @@ public async Task<string> MultiModalCompletion([FromBody] IncomingMessageModel i
105108
return error;
106109
}
107110
}
111+
#endregion
108112

113+
#region Generate image
109114
[HttpPost("/instruct/image-generation")]
110115
public async Task<ImageGenerationViewModel> ImageGeneration([FromBody] IncomingMessageModel input)
111116
{
@@ -129,7 +134,9 @@ public async Task<ImageGenerationViewModel> ImageGeneration([FromBody] IncomingM
129134
return imageViewModel;
130135
}
131136
}
137+
#endregion
132138

139+
#region Edit image
133140
[HttpPost("/instruct/image-variation")]
134141
public async Task<ImageGenerationViewModel> ImageVariation([FromBody] IncomingMessageModel input)
135142
{
@@ -140,12 +147,12 @@ public async Task<ImageGenerationViewModel> ImageVariation([FromBody] IncomingMe
140147

141148
try
142149
{
143-
var file = input.Files.FirstOrDefault(x => !string.IsNullOrWhiteSpace(x.FileUrl) || !string.IsNullOrWhiteSpace(x.FileData));
144-
if (file == null)
150+
var image = input.Files.FirstOrDefault(x => !string.IsNullOrWhiteSpace(x.FileUrl) || !string.IsNullOrWhiteSpace(x.FileData));
151+
if (image == null)
145152
{
146153
return new ImageGenerationViewModel { Message = "Error! Cannot find an image!" };
147154
}
148-
var message = await fileService.VarifyImage(input.Provider, input.Model, file);
155+
var message = await fileService.VaryImage(input.Provider, input.Model, image);
149156
imageViewModel.Content = message.Content;
150157
imageViewModel.Images = message.GeneratedImages.Select(x => ImageViewModel.ToViewModel(x)).ToList();
151158
return imageViewModel;
@@ -159,6 +166,67 @@ public async Task<ImageGenerationViewModel> ImageVariation([FromBody] IncomingMe
159166
}
160167
}
161168

169+
[HttpPost("/instruct/image-edit")]
170+
public async Task<ImageGenerationViewModel> ImageEdit([FromBody] IncomingMessageModel input)
171+
{
172+
var fileService = _services.GetRequiredService<IBotSharpFileService>();
173+
var state = _services.GetRequiredService<IConversationStateService>();
174+
input.States.ForEach(x => state.SetState(x.Key, x.Value, activeRounds: x.ActiveRounds, source: StateSource.External));
175+
var imageViewModel = new ImageGenerationViewModel();
176+
177+
try
178+
{
179+
var image = input.Files.FirstOrDefault(x => !string.IsNullOrWhiteSpace(x.FileUrl) || !string.IsNullOrWhiteSpace(x.FileData));
180+
if (image == null)
181+
{
182+
return new ImageGenerationViewModel { Message = "Error! Cannot find an image!" };
183+
}
184+
var message = await fileService.EditImage(input.Provider, input.Model, input.Text, image);
185+
imageViewModel.Content = message.Content;
186+
imageViewModel.Images = message.GeneratedImages.Select(x => ImageViewModel.ToViewModel(x)).ToList();
187+
return imageViewModel;
188+
}
189+
catch (Exception ex)
190+
{
191+
var error = $"Error in image edit. {ex.Message}";
192+
_logger.LogError(error);
193+
imageViewModel.Message = error;
194+
return imageViewModel;
195+
}
196+
}
197+
198+
[HttpPost("/instruct/image-mask-edit")]
199+
public async Task<ImageGenerationViewModel> ImageMaskEdit([FromBody] IncomingMessageModel input)
200+
{
201+
var fileService = _services.GetRequiredService<IBotSharpFileService>();
202+
var state = _services.GetRequiredService<IConversationStateService>();
203+
input.States.ForEach(x => state.SetState(x.Key, x.Value, activeRounds: x.ActiveRounds, source: StateSource.External));
204+
var imageViewModel = new ImageGenerationViewModel();
205+
206+
try
207+
{
208+
var image = input.Files.FirstOrDefault(x => !string.IsNullOrWhiteSpace(x.FileUrl) || !string.IsNullOrWhiteSpace(x.FileData));
209+
var mask = input.Mask;
210+
if (image == null || mask == null)
211+
{
212+
return new ImageGenerationViewModel { Message = "Error! Cannot find an image or mask!" };
213+
}
214+
var message = await fileService.EditImage(input.Provider, input.Model, input.Text, image, mask);
215+
imageViewModel.Content = message.Content;
216+
imageViewModel.Images = message.GeneratedImages.Select(x => ImageViewModel.ToViewModel(x)).ToList();
217+
return imageViewModel;
218+
}
219+
catch (Exception ex)
220+
{
221+
var error = $"Error in image mask edit. {ex.Message}";
222+
_logger.LogError(error);
223+
imageViewModel.Message = error;
224+
return imageViewModel;
225+
}
226+
}
227+
#endregion
228+
229+
#region Pdf
162230
[HttpPost("/instruct/pdf-completion")]
163231
public async Task<PdfCompletionViewModel> PdfCompletion([FromBody] IncomingMessageModel input)
164232
{
@@ -181,4 +249,5 @@ public async Task<PdfCompletionViewModel> PdfCompletion([FromBody] IncomingMessa
181249
return viewModel;
182250
}
183251
}
252+
#endregion
184253
}
Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
using OpenAI.Images;
2+
3+
namespace BotSharp.Plugin.AzureOpenAI.Providers.Image;
4+
5+
public partial class ImageCompletionProvider
6+
{
7+
public async Task<RoleDialogModel> GetImageEdits(Agent agent, RoleDialogModel message, Stream image, string imageFileName)
8+
{
9+
var client = ProviderHelper.GetClient(Provider, _model, _services);
10+
var (prompt, imageCount, options) = PrepareEditOptions(message);
11+
var imageClient = client.GetImageClient(_model);
12+
13+
var response = imageClient.GenerateImageEdits(image, imageFileName, prompt, imageCount, options);
14+
var images = response.Value;
15+
16+
var generatedImages = GetImageGenerations(images, options.ResponseFormat);
17+
var content = string.Join("\r\n", generatedImages.Where(x => !string.IsNullOrWhiteSpace(x.Description)).Select(x => x.Description));
18+
var responseMessage = new RoleDialogModel(AgentRole.Assistant, content)
19+
{
20+
CurrentAgentId = agent.Id,
21+
MessageId = message?.MessageId ?? string.Empty,
22+
GeneratedImages = generatedImages
23+
};
24+
25+
return await Task.FromResult(responseMessage);
26+
}
27+
28+
public async Task<RoleDialogModel> GetImageEdits(Agent agent, RoleDialogModel message,
29+
Stream image, string imageFileName, Stream mask, string maskFileName)
30+
{
31+
var client = ProviderHelper.GetClient(Provider, _model, _services);
32+
var (prompt, imageCount, options) = PrepareEditOptions(message);
33+
var imageClient = client.GetImageClient(_model);
34+
35+
var response = imageClient.GenerateImageEdits(image, imageFileName, prompt, mask, maskFileName, imageCount, options);
36+
var images = response.Value;
37+
38+
var generatedImages = GetImageGenerations(images, options.ResponseFormat);
39+
var content = string.Join("\r\n", generatedImages.Where(x => !string.IsNullOrWhiteSpace(x.Description)).Select(x => x.Description));
40+
var responseMessage = new RoleDialogModel(AgentRole.Assistant, content)
41+
{
42+
CurrentAgentId = agent.Id,
43+
MessageId = message?.MessageId ?? string.Empty,
44+
GeneratedImages = generatedImages
45+
};
46+
47+
return await Task.FromResult(responseMessage);
48+
}
49+
50+
private (string, int, ImageEditOptions) PrepareEditOptions(RoleDialogModel message)
51+
{
52+
var prompt = message?.Payload ?? message?.Content ?? string.Empty;
53+
54+
var state = _services.GetRequiredService<IConversationStateService>();
55+
var size = GetImageSize(state.GetState("image_size"));
56+
var format = GetImageFormat(state.GetState("image_format"));
57+
var count = GetImageCount(state.GetState("image_count", "1"));
58+
59+
var options = new ImageEditOptions
60+
{
61+
Size = size,
62+
ResponseFormat = format
63+
};
64+
return (prompt, count, options);
65+
}
66+
}

src/Plugins/BotSharp.Plugin.AzureOpenAI/Providers/Image/ImageCompletionProvider.Generation.cs

Lines changed: 4 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -7,36 +7,13 @@ public partial class ImageCompletionProvider
77
public async Task<RoleDialogModel> GetImageGeneration(Agent agent, RoleDialogModel message)
88
{
99
var client = ProviderHelper.GetClient(Provider, _model, _services);
10-
var (prompt, imageCount, options) = PrepareOptions(message);
10+
var (prompt, imageCount, options) = PrepareGenerationOptions(message);
1111
var imageClient = client.GetImageClient(_model);
1212

1313
var response = imageClient.GenerateImages(prompt, imageCount, options);
14-
var values = response.Value;
15-
16-
var generatedImages = new List<ImageGeneration>();
17-
foreach (var value in values)
18-
{
19-
if (value == null) continue;
20-
21-
var generatedImage = new ImageGeneration { Description = value?.RevisedPrompt ?? string.Empty };
22-
if (options.ResponseFormat == GeneratedImageFormat.Uri)
23-
{
24-
generatedImage.ImageUrl = value?.ImageUri?.AbsoluteUri ?? string.Empty;
25-
}
26-
else if (options.ResponseFormat == GeneratedImageFormat.Bytes)
27-
{
28-
var base64Str = string.Empty;
29-
var bytes = value?.ImageBytes?.ToArray();
30-
if (!bytes.IsNullOrEmpty())
31-
{
32-
base64Str = Convert.ToBase64String(bytes);
33-
}
34-
generatedImage.ImageData = base64Str;
35-
}
36-
37-
generatedImages.Add(generatedImage);
38-
}
14+
var images = response.Value;
3915

16+
var generatedImages = GetImageGenerations(images, options.ResponseFormat);
4017
var content = string.Join("\r\n", generatedImages.Where(x => !string.IsNullOrWhiteSpace(x.Description)).Select(x => x.Description));
4118
var responseMessage = new RoleDialogModel(AgentRole.Assistant, content)
4219
{
@@ -48,7 +25,7 @@ public async Task<RoleDialogModel> GetImageGeneration(Agent agent, RoleDialogMod
4825
return await Task.FromResult(responseMessage);
4926
}
5027

51-
private (string, int, ImageGenerationOptions) PrepareOptions(RoleDialogModel message)
28+
private (string, int, ImageGenerationOptions) PrepareGenerationOptions(RoleDialogModel message)
5229
{
5330
var prompt = message?.Payload ?? message?.Content ?? string.Empty;
5431

0 commit comments

Comments
 (0)