|
1 |
| -using Fritz.StreamLib.Core; |
2 |
| -using Microsoft.Azure.CognitiveServices.Vision.CustomVision.Prediction; |
3 |
| -using Microsoft.Azure.CognitiveServices.Vision.CustomVision.Prediction.Models; |
| 1 | +using Azure.AI.OpenAI; |
| 2 | +using Fritz.StreamLib.Core; |
| 3 | +using Fritz.StreamTools.Hubs; |
| 4 | +using Microsoft.AspNetCore.SignalR; |
| 5 | +using Microsoft.Extensions.AI; |
4 | 6 | using Microsoft.Extensions.Configuration;
|
5 | 7 | using System;
|
6 |
| -using System.Linq; |
| 8 | +using System.ClientModel; |
7 | 9 | using System.Collections.Generic;
|
8 |
| -using System.Collections.Immutable; |
9 |
| -using System.Text; |
| 10 | +using System.Diagnostics; |
| 11 | +using System.IO; |
10 | 12 | using System.Threading.Tasks;
|
11 |
| -using Microsoft.Azure.CognitiveServices.Vision.CustomVision.Training; |
12 |
| -using System.Net; |
13 |
| -using Microsoft.AspNetCore.SignalR; |
14 |
| -using Fritz.StreamTools.Hubs; |
15 | 13 |
|
16 | 14 | namespace Fritz.Chatbot.Commands
|
17 | 15 | {
|
/// <summary>
/// Chat command (trigger <c>hat</c>) that grabs a screenshot from OBS and asks an
/// Azure OpenAI chat model to describe the hat in the picture.
/// </summary>
/// <remarks>
/// A successful analysis is remembered for <see cref="_LastRunCooldown"/>; during that window
/// every caller except <c>csharpfritz</c> is answered from the remembered result instead of
/// triggering a new model call.
/// </remarks>
public class PredictHatCommand : IBasicCommand
{
    public string Trigger => "hat";
    public string Description => "Describe which hat Fritz is wearing";
    public TimeSpan? Cooldown => TimeSpan.FromSeconds(30);

    // How long a successful analysis is replayed before the model is queried again.
    private readonly TimeSpan _LastRunCooldown = TimeSpan.FromMinutes(5);

    // Both stay null until the first successful analysis; they are always set together.
    private DateTimeOffset? _LastRun;
    private HatMetadata? _LastHatDetected;

    private readonly ScreenshotTrainingService _ObsConnection;

    // Injected by the constructor; not used by the code currently in this class.
    private readonly IHubContext<ObsHub> _HubContext;

    private readonly string _AzureOpenAiEndpoint;
    private readonly string _AzureOpenAiKey;
    private readonly string _AzureOpenAiModel;

    // Created on first use and then reused, instead of building a new client per command.
    private IChatClient? _ChatClient;

    /// <summary>
    /// Reads the Azure OpenAI settings (<c>AzureOpenAiEndpoint</c>, <c>AzureOpenAiKey</c>,
    /// <c>AzureOpenAiModel</c>) from configuration and stores the injected services.
    /// </summary>
    public PredictHatCommand(IConfiguration configuration, ScreenshotTrainingService service, IHubContext<ObsHub> hubContext)
    {
        _AzureOpenAiEndpoint = configuration["AzureOpenAiEndpoint"];
        _AzureOpenAiKey = configuration["AzureOpenAiKey"];
        _AzureOpenAiModel = configuration["AzureOpenAiModel"];
        _ObsConnection = service;
        _HubContext = hubContext;
    }

    /// <summary>
    /// Answers the command: replays the last analysis while it is still fresh, otherwise
    /// captures a screenshot, asks the model about the hat and reports the result to chat.
    /// </summary>
    /// <param name="chatService">Chat service the replies are sent through.</param>
    /// <param name="userName">Name of the user who issued the command.</param>
    /// <param name="rhs">Text following the trigger; not used by this command.</param>
    public async Task Execute(IChatService chatService, string userName, ReadOnlyMemory<char> rhs)
    {
        // Identifier comparison: ordinal, so culture-specific equivalences cannot match the name.
        var isBroadcaster = userName.Equals("csharpfritz", StringComparison.OrdinalIgnoreCase);

        if (!isBroadcaster && _LastHatDetected is { } lastHat && _LastRun is { } lastRun)
        {
            var age = DateTimeOffset.UtcNow - lastRun;
            if (age < _LastRunCooldown)
            {
                // send a message about the last hat detected
                await chatService.SendMessageAsync($"@{userName} from my last analysis {Math.Round(age.TotalMinutes)} minutes ago I can tell you Fritz's hat is:");

                // Name is guaranteed non-empty for a remembered result, so it is a safe
                // fallback when the model returned no description.
                var summary = string.IsNullOrWhiteSpace(lastHat.Description) ? lastHat.Name : lastHat.Description;
                await chatService.SendMessageAsync(summary);
                return;
            }
        }

        HatMetadata? hatDetected;

        try
        {
            // Captured inside the try so a failed screenshot is reported to chat like a
            // failed prediction instead of escaping as an unhandled exception.
            // NOTE(review): the returned stream is not disposed here because its ownership
            // is not visible from this class - confirm whether the caller should dispose it.
            var obsImage = await _ObsConnection.GetScreenshotFromObs();
            hatDetected = await PredictHat(obsImage);
        }
        catch (Exception ex)
        {
            // Keep the failure visible for diagnosis; chat only gets the generic message.
            Console.WriteLine($"Hat detection failed: {ex}");
            await chatService.SendMessageAsync("There was an error detecting this hat. Please try again in 30 seconds");
            return;
        }

        if (hatDetected is null || string.IsNullOrEmpty(hatDetected.Name))
        {
            await chatService.SendMessageAsync("There was an error detecting a hat. Please try again in 30 seconds");
            return;
        }

        // send a message about the hat detected
        await chatService.SendMessageAsync($"@{userName} I can tell you Fritz's hat is: {hatDetected.Name}");

        // The model may leave these fields empty; skip them rather than post an empty message.
        if (!string.IsNullOrWhiteSpace(hatDetected.Description))
        {
            await chatService.SendMessageAsync(hatDetected.Description);
        }

        if (!string.IsNullOrWhiteSpace(hatDetected.Conclusion))
        {
            await chatService.SendMessageAsync(hatDetected.Conclusion);
        }

        _LastHatDetected = hatDetected;
        _LastRun = DateTimeOffset.UtcNow;
    }

    /// <summary>
    /// Sends the screenshot to the chat model together with the hat-description prompt and
    /// returns the structured <see cref="HatMetadata"/> produced from the response.
    /// </summary>
    /// <param name="obsImage">Screenshot stream; read from its current position to the end.</param>
    /// <returns>The hat metadata produced from the model response.</returns>
    private async Task<HatMetadata> PredictHat(Stream obsImage)
    {
        _ChatClient ??= new AzureOpenAIClient(
                new Uri(_AzureOpenAiEndpoint),
                new ApiKeyCredential(_AzureOpenAiKey))
            .AsChatClient(_AzureOpenAiModel);

        var systemPrompt =
            """
            You are an AI assistant that can detect and describe baseball-style
            hats from an image. The user will provide an image of someone wearing a hat
            and they might be also wearing a headset. Ignore the headset and focus on the hat.

            This hat comes from a collection that includes hats that fall into one of these categories:
            Sports Teams, Colleges, Marvel Comics, Microsoft logos, Star Wars,
            Video Games, and other popular culture references.

            Be as descriptive as possible about the hat focusing on its design, colors,
            logo placement, text, category, organization that the hat references, and any notable features.

            Limit the description to 3 sentences.

            Suggest a name for the hat based on the color, organization, or logo.

            If the hat references a sports team or college, you should conclude with that school's slogan or cheer.

            If the hat references a Marvel comic character, you should conclude with a comment about the hat in the tone of the references character or organization
            """;

        // Buffer through a MemoryStream: works for non-seekable streams and does not block
        // the calling thread the way a synchronous read would.
        using var buffer = new MemoryStream();
        await obsImage.CopyToAsync(buffer);
        var file = buffer.ToArray();

        var messages = new List<ChatMessage>
        {
            new ChatMessage(ChatRole.System, systemPrompt),
            new ChatMessage(ChatRole.User, new AIContent[] {
                new ImageContent(file, "image/webp"),
                new TextContent("Generate a description of the hat"),
            })
        };

        var sw = Stopwatch.StartNew();
        var response = await _ChatClient.CompleteAsync<HatMetadata>(messages, options: new ChatOptions { Temperature = 0.1f });

        Console.WriteLine($"Elapsed: {sw.Elapsed}");
        Console.WriteLine($"Predicted hat: {response.Result}");
        return response.Result;
    }
}
|
| 137 | + |
/// <summary>
/// Structured description of a hat, used as the result type requested from the chat model.
/// </summary>
/// <param name="Color">Hat color as reported by the model.</param>
/// <param name="Name">Suggested name for the hat; an empty value is treated as a failed detection.</param>
/// <param name="HasLogo">Whether the model reported a logo on the hat.</param>
/// <param name="LogoShape">Shape of the logo, or <c>null</c> when none is reported.</param>
/// <param name="Text">Text the model reported on the hat.</param>
/// <param name="Description">Free-form description of the hat that is posted to chat.</param>
/// <param name="Category">Category the model assigned to the hat.</param>
/// <param name="LogoDescription">Description of the logo as reported by the model.</param>
/// <param name="Organization">Organization the hat references, as reported by the model.</param>
/// <param name="Conclusion">Closing remark from the model that is posted to chat.</param>
internal record HatMetadata(
    string Color, string Name,
    bool HasLogo, string? LogoShape,
    string Text, string Description,
    string Category, string LogoDescription,
    string Organization, string Conclusion);
| 149 | + |
115 | 150 | }
|
0 commit comments