|
1 |
| -using Fritz.StreamLib.Core; |
| 1 | +using Azure.AI.OpenAI; |
| 2 | +using Fritz.StreamLib.Core; |
2 | 3 | using Fritz.StreamTools.Hubs;
|
3 | 4 | using Microsoft.AspNetCore.SignalR;
|
4 |
| -using Microsoft.Azure.CognitiveServices.Vision.CustomVision.Prediction; |
| 5 | +using Microsoft.Extensions.AI; |
5 | 6 | using Microsoft.Extensions.Configuration;
|
6 | 7 | using System;
|
| 8 | +using System.ClientModel; |
| 9 | +using System.Collections.Generic; |
| 10 | +using System.Diagnostics; |
| 11 | +using System.IO; |
7 | 12 | using System.Threading.Tasks;
|
8 | 13 |
|
9 | 14 | namespace Fritz.Chatbot.Commands
|
10 | 15 | {
|
11 | 16 | public class PredictHatCommand : IBasicCommand
|
12 | 17 | {
|
13 | 18 | public string Trigger => "hat";
|
14 |
| - public string Description => "Identify which hat Fritz is wearing"; |
| 19 | + public string Description => "Describe which hat Fritz is wearing"; |
15 | 20 | public TimeSpan? Cooldown => TimeSpan.FromSeconds(30);
|
16 | 21 |
|
17 |
| - private string _CustomVisionKey = ""; |
18 |
| - private string _AzureEndpoint = ""; |
19 |
| - private string _TwitchChannel = ""; |
20 |
| - private Guid _AzureProjectId; |
| 22 | + private DateTimeOffset? _LastRun = DateTimeOffset.MinValue; |
| 23 | + private TimeSpan _LastRunCooldown = TimeSpan.FromMinutes(5); |
| 24 | + private HatMetadata _LastHatDetected = null!; |
21 | 25 |
|
22 |
| - internal static string IterationName = ""; |
23 |
| - private ScreenshotTrainingService _TrainHat; |
| 26 | + private ScreenshotTrainingService _ObsConnection; |
24 | 27 | //private readonly HatDescriptionRepository _Repository;
|
25 | 28 | private readonly IHubContext<ObsHub> _HubContext;
|
| 29 | + private readonly string _AzureOpenAiEndpoint; |
| 30 | + private readonly string _AzureOpenAiKey; |
| 31 | + private readonly string _AzureOpenAiModel; |
26 | 32 |
|
27 | 33 | public PredictHatCommand(IConfiguration configuration, ScreenshotTrainingService service, IHubContext<ObsHub> hubContext)
|
28 | 34 | {
|
29 |
| - _CustomVisionKey = configuration["AzureServices:HatDetection:Key"]; |
30 |
| - _AzureEndpoint = configuration["AzureServices:HatDetection:CustomVisionEndpoint"]; |
31 |
| - _TwitchChannel = configuration["StreamServices:Twitch:Channel"]; |
32 |
| - _AzureProjectId = Guid.Parse(configuration["AzureServices:HatDetection:ProjectId"]); |
33 |
| - _TrainHat = service; |
34 |
| - //_Repository = repository; |
| 35 | + _AzureOpenAiEndpoint = configuration["AzureOpenAiEndpoint"]; |
| 36 | + _AzureOpenAiKey = configuration["AzureOpenAiKey"]; |
| 37 | + _AzureOpenAiModel = configuration["AzureOpenAiModel"]; |
| 38 | + _ObsConnection = service; |
35 | 39 | _HubContext = hubContext;
|
36 | 40 | }
|
37 | 41 |
|
38 | 42 | public async Task Execute(IChatService chatService, string userName, ReadOnlyMemory<char> rhs)
|
39 | 43 | {
|
40 | 44 |
|
41 |
| - var client = new CustomVisionPredictionClient() |
| 45 | + if (!userName.Equals("csharpfritz", StringComparison.InvariantCultureIgnoreCase) |
| 46 | + && _LastRun.HasValue && DateTimeOffset.Now - _LastRun < _LastRunCooldown) |
42 | 47 | {
|
43 |
| - ApiKey = _CustomVisionKey, |
44 |
| - Endpoint = _AzureEndpoint, |
45 |
| - }; |
| 48 | + // send a message about the last hat detected |
| 49 | + await chatService.SendMessageAsync($"@{userName} from my last analysis {Math.Round(DateTimeOffset.Now.Subtract(_LastRun.Value).TotalMinutes)} minutes ago I can tell you Fritz's hat is:"); |
| 50 | + await chatService.SendMessageAsync(_LastHatDetected.Description); |
| 51 | + return; |
| 52 | + } |
46 | 53 |
|
47 |
| - await _HubContext.Clients.All.SendAsync("shutter"); |
48 |
| - var obsImage = await _TrainHat.GetScreenshotFromObs(); |
| 54 | + var obsImage = await _ObsConnection.GetScreenshotFromObs(); |
49 | 55 |
|
50 | 56 | ////////////////////////////
|
51 | 57 |
|
52 | 58 |
|
| 59 | + HatMetadata hatDetected; |
| 60 | + |
| 61 | + try |
| 62 | + { |
| 63 | + hatDetected = await PredictHat(obsImage); |
| 64 | + } |
| 65 | + catch |
| 66 | + { |
| 67 | + await chatService.SendMessageAsync("There was an error detecting this hat. Please try again in 30 seconds"); |
| 68 | + return; |
| 69 | + } |
| 70 | + |
| 71 | + if (string.IsNullOrEmpty(hatDetected.Name)) |
| 72 | + { |
| 73 | + await chatService.SendMessageAsync("There was an error detecting a hat. Please try again in 30 seconds"); |
| 74 | + return; |
| 75 | + } |
| 76 | + |
| 77 | + // send a message about the hat detected |
| 78 | + await chatService.SendMessageAsync($"@{userName} I can tell you Fritz's hat is: {hatDetected.Name}"); |
| 79 | + await chatService.SendMessageAsync(hatDetected.Description); |
| 80 | + await chatService.SendMessageAsync(hatDetected.Conclusion); |
| 81 | + _LastHatDetected = hatDetected; |
| 82 | + _LastRun = DateTimeOffset.Now; |
| 83 | + |
| 84 | + } |
| 85 | + |
| 86 | + private async Task<HatMetadata> PredictHat(Stream obsImage) |
| 87 | + { |
| 88 | + |
| 89 | + var client = new AzureOpenAIClient( |
| 90 | + new Uri(_AzureOpenAiEndpoint), |
| 91 | + new ApiKeyCredential(_AzureOpenAiKey)) |
| 92 | + .AsChatClient(_AzureOpenAiModel); |
| 93 | + |
| 94 | + var systemPrompt = |
| 95 | + """ |
| 96 | + You are an AI assistant that can detect and describe baseball-style |
| 97 | + hats from an image. The user will provide an image of someone wearing a hat |
| 98 | + and they might be also wearing a headset. Ignore the headset and focus on the hat. |
| 99 | +
|
| 100 | + This hat comes from a collection that includes hats that fall into one of these categories: |
| 101 | + Sports Teams, Colleges, Marvel Comics, Microsoft logos, Star Wars, |
| 102 | + Video Games, and other popular culture references. |
| 103 | +
|
| 104 | + Be as descriptive as possible about the hat focusing on its design, colors, |
| 105 | + logo placement, text, category, organization that the hat references, and any notable features. |
| 106 | +
|
| 107 | + Limit the description to 3 sentences. |
| 108 | +
|
| 109 | + Suggest a name for the hat based on the color, organization, or logo. |
| 110 | +
|
| 111 | + If the hat references a sports team or college, you should conclude with that school's slogan or cheer. |
| 112 | +
|
| 113 | + If the hat references a Marvel comic character, you should conclude with a comment about the hat in the tone of the references character or organization |
| 114 | + """; |
| 115 | + |
| 116 | + var file = new byte[obsImage.Length]; |
| 117 | + obsImage.ReadExactly(file, 0, (int)obsImage.Length); |
| 118 | + var messages = new List<ChatMessage> |
| 119 | + { |
| 120 | + new ChatMessage(ChatRole.System, systemPrompt), |
| 121 | + new ChatMessage(ChatRole.User, new AIContent[] { |
| 122 | + new ImageContent(file, "image/webp"), // |
| 123 | + new TextContent("Generate a description of the hat"), // Generate a description of the hat |
| 124 | + }) |
| 125 | + }; |
| 126 | + |
| 127 | + var sw = Stopwatch.StartNew(); |
| 128 | + var response = await client.CompleteAsync<HatMetadata>(messages, options: new ChatOptions { Temperature = 0.1f }); |
| 129 | + |
| 130 | + Console.WriteLine($"Elapsed: {sw.Elapsed}"); |
| 131 | + Console.WriteLine($"Predicted hat: {response.Result}"); |
| 132 | + return response.Result; |
| 133 | + |
53 | 134 | }
|
54 | 135 |
|
55 | 136 | }
|
56 | 137 |
|
57 |
| - internal record HatData(string Name, string Description); |
| 138 | + internal record HatMetadata( |
| 139 | + string Color, |
| 140 | + string Name, |
| 141 | + bool HasLogo, |
| 142 | + string? LogoShape, |
| 143 | + string Text, |
| 144 | + string Description, |
| 145 | + string Category, |
| 146 | + string LogoDescription, |
| 147 | + string Organization, |
| 148 | + string Conclusion); |
58 | 149 |
|
59 | 150 | }
|
0 commit comments