|
| 1 | +using OpenAI.Images; |
| 2 | + |
| 3 | +namespace BotSharp.Plugin.AzureOpenAI.Providers.Image; |
| 4 | + |
/// <summary>
/// <see cref="IImageVariation"/> implementation registered under the "azure-openai" provider name.
/// Sends a source image to the image client's variation endpoint and wraps the returned
/// images in an assistant <see cref="RoleDialogModel"/>.
/// </summary>
public class ImageVariationProvider : IImageVariation
{
    protected readonly AzureOpenAiSettings _settings;
    protected readonly IServiceProvider _services;
    protected readonly ILogger<ImageVariationProvider> _logger;

    // Fallback used when the "image_count" state is missing, unparsable, or out of range.
    private const int DEFAULT_IMAGE_COUNT = 1;

    // Largest image count accepted from conversation state; anything above falls back to the default.
    private const int IMAGE_COUNT_LIMIT = 5;

    // Model (deployment) name; assigned through SetModelName before a request is made.
    protected string _model;

    /// <summary>Provider key passed to <c>ProviderHelper.GetClient</c>.</summary>
    public virtual string Provider => "azure-openai";

    public ImageVariationProvider(
        AzureOpenAiSettings settings,
        ILogger<ImageVariationProvider> logger,
        IServiceProvider services)
    {
        _settings = settings;
        _services = services;
        _logger = logger;
    }

    /// <summary>
    /// Requests variations of <paramref name="image"/> and returns them as an assistant message.
    /// </summary>
    /// <param name="agent">Agent whose id is stamped on the response as <c>CurrentAgentId</c>.</param>
    /// <param name="message">Originating message; its <c>MessageId</c> is copied to the response (empty string when null).</param>
    /// <param name="image">Stream containing the source image.</param>
    /// <param name="imageFileName">File name forwarded to the image client together with the stream.</param>
    /// <returns>
    /// An assistant message whose <c>GeneratedImages</c> holds one entry per non-null result and whose
    /// content is the non-blank descriptions joined with CRLF.
    /// </returns>
    public async Task<RoleDialogModel> GetImageVariation(Agent agent, RoleDialogModel message, Stream image, string imageFileName)
    {
        var client = ProviderHelper.GetClient(Provider, _model, _services);
        var (imageCount, options) = PrepareOptions();
        var imageClient = client.GetImageClient(_model);

        // Await the asynchronous overload so the calling thread is not blocked for the
        // duration of the HTTP round trip (the method was already declared async).
        var response = await imageClient.GenerateImageVariationsAsync(image, imageFileName, imageCount, options);

        var generatedImages = new List<ImageGeneration>();
        foreach (var value in response.Value)
        {
            if (value == null)
            {
                continue;
            }

            var generatedImage = new ImageGeneration { Description = value.RevisedPrompt ?? string.Empty };
            if (options.ResponseFormat == GeneratedImageFormat.Uri)
            {
                generatedImage.ImageUrl = value.ImageUri?.AbsoluteUri ?? string.Empty;
            }
            else if (options.ResponseFormat == GeneratedImageFormat.Bytes)
            {
                // Missing or empty payloads are surfaced as an empty string rather than null.
                var bytes = value.ImageBytes?.ToArray();
                generatedImage.ImageData = bytes is { Length: > 0 }
                    ? Convert.ToBase64String(bytes)
                    : string.Empty;
            }

            generatedImages.Add(generatedImage);
        }

        var content = string.Join("\r\n", generatedImages
            .Where(x => !string.IsNullOrWhiteSpace(x.Description))
            .Select(x => x.Description));

        return new RoleDialogModel(AgentRole.Assistant, content)
        {
            CurrentAgentId = agent.Id,
            MessageId = message?.MessageId ?? string.Empty,
            GeneratedImages = generatedImages
        };
    }

    /// <summary>Sets the model name used for client lookup and for the image client.</summary>
    /// <param name="model">Model (deployment) name.</param>
    public void SetModelName(string model)
    {
        _model = model;
    }

    /// <summary>
    /// Builds the request options from the conversation states "image_size", "image_format"
    /// and "image_count".
    /// </summary>
    /// <returns>The validated image count and the variation options.</returns>
    private (int Count, ImageVariationOptions Options) PrepareOptions()
    {
        var state = _services.GetRequiredService<IConversationStateService>();
        var size = state.GetState("image_size");
        var format = state.GetState("image_format");
        var count = GetImageCount(state.GetState("image_count", "1"));

        var options = new ImageVariationOptions
        {
            Size = GetImageSize(size),
            ResponseFormat = GetImageFormat(format)
        };
        return (count, options);
    }

    /// <summary>
    /// Maps a "WIDTHxHEIGHT" string to a <see cref="GeneratedImageSize"/>.
    /// Null, empty, or unrecognized values map to 1024x1024.
    /// </summary>
    /// <remarks>
    /// NOTE(review): confirm that the variation endpoint accepts the non-square sizes listed here.
    /// </remarks>
    private GeneratedImageSize GetImageSize(string size) => size switch
    {
        "256x256" => GeneratedImageSize.W256xH256,
        "512x512" => GeneratedImageSize.W512xH512,
        "1024x1024" => GeneratedImageSize.W1024xH1024,
        "1024x1792" => GeneratedImageSize.W1024xH1792,
        "1792x1024" => GeneratedImageSize.W1792xH1024,
        _ => GeneratedImageSize.W1024xH1024
    };

    /// <summary>
    /// Maps "uri" or "bytes" (case-sensitive) to a <see cref="GeneratedImageFormat"/>.
    /// Null, empty, or unrecognized values map to <see cref="GeneratedImageFormat.Uri"/>.
    /// </summary>
    private GeneratedImageFormat GetImageFormat(string format) => format switch
    {
        "bytes" => GeneratedImageFormat.Bytes,
        _ => GeneratedImageFormat.Uri
    };

    /// <summary>
    /// Parses the requested image count. Values that fail to parse or fall outside
    /// 1..<see cref="IMAGE_COUNT_LIMIT"/> yield <see cref="DEFAULT_IMAGE_COUNT"/>.
    /// </summary>
    private int GetImageCount(string count)
    {
        if (!int.TryParse(count, out var parsed))
        {
            return DEFAULT_IMAGE_COUNT;
        }

        var withinLimit = parsed > 0 && parsed <= IMAGE_COUNT_LIMIT;
        return withinLimit ? parsed : DEFAULT_IMAGE_COUNT;
    }
}
0 commit comments