// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.

using System.Diagnostics;
using System.Net.Http.Json;
using Microsoft.Extensions.Logging;

namespace Microsoft.DevProxy.Abstractions;

public class OllamaLanguageModelClient(LanguageModelConfiguration? configuration, ILogger logger) : ILanguageModelClient
{
    private readonly LanguageModelConfiguration? _configuration = configuration;
    private readonly ILogger _logger = logger;
    private bool? _lmAvailable;
    private readonly Dictionary<string, OllamaLanguageModelCompletionResponse> _cacheCompletion = new();
    private readonly Dictionary<ILanguageModelChatCompletionMessage[], OllamaLanguageModelChatCompletionResponse> _cacheChatCompletion = new();

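    // Reports whether the configured language model can be used: configuration is
    // present and enabled, the Ollama endpoint responds, and a test completion
    // succeeds. The result is cached for the lifetime of this client instance, so
    // configuration changes require a new instance.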
    public async Task<bool> IsEnabled()
    {
        if (_lmAvailable.HasValue)
        {
            return _lmAvailable.Value;
        }

        _lmAvailable = await IsEnabledInternal();
        return _lmAvailable.Value;
    }

    private async Task<bool> IsEnabledInternal()
    {
        if (_configuration is null || !_configuration.Enabled)
        {
            return false;
        }

        if (string.IsNullOrEmpty(_configuration.Url))
        {
            _logger.LogError("URL is not set. Language model will be disabled");
            return false;
        }

        if (string.IsNullOrEmpty(_configuration.Model))
        {
            _logger.LogError("Model is not set. Language model will be disabled");
            return false;
        }

        _logger.LogDebug("Checking LM availability at {url}...", _configuration.Url);

        try
        {
            // Check that the language model endpoint is up before issuing a test prompt.
            using var client = new HttpClient();
            var response = await client.GetAsync(_configuration.Url);
            _logger.LogDebug("Response: {response}", response.StatusCode);

            if (!response.IsSuccessStatusCode)
            {
                return false;
            }

            var testCompletion = await GenerateCompletionInternal("Are you there? Reply with a yes or no.");
            if (testCompletion?.Error is not null)
            {
                _logger.LogError("Error: {error}", testCompletion.Error);
                return false;
            }

            return true;
        }
        catch (Exception ex)
        {
            _logger.LogError(ex, "Couldn't reach language model at {url}", _configuration.Url);
            return false;
        }
    }

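    // Generates a completion for the given prompt, serving it from the in-memory
    // cache when CacheResponses is on. Returns null when configuration is missing,
    // availability hasn't been checked, the model is unavailable, or the request fails.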
    public async Task<ILanguageModelCompletionResponse?> GenerateCompletion(string prompt)
    {
        using var scope = _logger.BeginScope(nameof(OllamaLanguageModelClient));

        if (_configuration is null)
        {
            return null;
        }

        if (!_lmAvailable.HasValue)
        {
            _logger.LogError("Language model availability is not checked. Call {isEnabled} first.", nameof(IsEnabled));
            return null;
        }

        if (!_lmAvailable.Value)
        {
            return null;
        }

        if (_configuration.CacheResponses && _cacheCompletion.TryGetValue(prompt, out var cachedResponse))
        {
            _logger.LogDebug("Returning cached response for prompt: {prompt}", prompt);
            return cachedResponse;
        }

        var response = await GenerateCompletionInternal(prompt);
        if (response is null)
        {
            return null;
        }

        if (response.Error is not null)
        {
            _logger.LogError(response.Error);
            return null;
        }

        if (_configuration.CacheResponses && response.Response is not null)
        {
            _cacheCompletion[prompt] = response;
        }

        return response;
    }

    private async Task<OllamaLanguageModelCompletionResponse?> GenerateCompletionInternal(string prompt)
    {
        Debug.Assert(_configuration != null, "Configuration is null");

        try
        {
            using var client = new HttpClient();
            var url = $"{_configuration.Url}/api/generate";
            _logger.LogDebug("Requesting completion. Prompt: {prompt}", prompt);

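            // Ollama's /api/generate endpoint; stream = false requests a single JSON
            // response object instead of a stream of partial chunks.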
            var response = await client.PostAsJsonAsync(url,
                new
                {
                    prompt,
                    model = _configuration.Model,
                    stream = false
                }
            );
            _logger.LogDebug("Response: {response}", response.StatusCode);

            var res = await response.Content.ReadFromJsonAsync<OllamaLanguageModelCompletionResponse>();
            if (res is null)
            {
                return res;
            }

            res.RequestUrl = url;
            return res;
        }
        catch (Exception ex)
        {
            _logger.LogError(ex, "Failed to generate completion");
            return null;
        }
    }

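    // Generates a chat completion for the given conversation. Cached responses are
    // looked up by comparing the message sequence element-wise (see
    // CacheChatCompletionExtensions below), since array keys would otherwise be
    // compared by reference.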
    public async Task<ILanguageModelCompletionResponse?> GenerateChatCompletion(ILanguageModelChatCompletionMessage[] messages)
    {
        using var scope = _logger.BeginScope(nameof(OllamaLanguageModelClient));

        if (_configuration is null)
        {
            return null;
        }

        if (!_lmAvailable.HasValue)
        {
            _logger.LogError("Language model availability is not checked. Call {isEnabled} first.", nameof(IsEnabled));
            return null;
        }

        if (!_lmAvailable.Value)
        {
            return null;
        }

        if (_configuration.CacheResponses && _cacheChatCompletion.TryGetCacheValue(messages, out var cachedResponse))
        {
            _logger.LogDebug("Returning cached response for message: {lastMessage}", messages.Last().Content);
            return cachedResponse;
        }

        var response = await GenerateChatCompletionInternal(messages);
        if (response is null)
        {
            return null;
        }

        if (response.Error is not null)
        {
            _logger.LogError(response.Error);
            return null;
        }

        if (_configuration.CacheResponses && response.Response is not null)
        {
            _cacheChatCompletion[messages] = response;
        }

        return response;
    }

    private async Task<OllamaLanguageModelChatCompletionResponse?> GenerateChatCompletionInternal(ILanguageModelChatCompletionMessage[] messages)
    {
        Debug.Assert(_configuration != null, "Configuration is null");

        try
        {
            using var client = new HttpClient();
            var url = $"{_configuration.Url}/api/chat";
            _logger.LogDebug("Requesting chat completion. Message: {lastMessage}", messages.Last().Content);

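            // Ollama's /api/chat endpoint takes the whole conversation as a messages
            // array; stream = false again requests a single, complete JSON response.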
            var response = await client.PostAsJsonAsync(url,
                new
                {
                    messages,
                    model = _configuration.Model,
                    stream = false
                }
            );
            _logger.LogDebug("Response: {response}", response.StatusCode);

            var res = await response.Content.ReadFromJsonAsync<OllamaLanguageModelChatCompletionResponse>();
            if (res is null)
            {
                return res;
            }

            res.RequestUrl = url;
            return res;
        }
        catch (Exception ex)
        {
            _logger.LogError(ex, "Failed to generate chat completion");
            return null;
        }
    }
}

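// Helpers for using message arrays as cache keys. Dictionary compares array keys by
// reference, so equal-but-distinct conversations would never produce a cache hit;
// these extensions scan the stored keys with SequenceEqual instead.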
internal static class CacheChatCompletionExtensions
{
    public static ILanguageModelChatCompletionMessage[]? GetKey(
        this Dictionary<ILanguageModelChatCompletionMessage[], OllamaLanguageModelChatCompletionResponse> cache,
        ILanguageModelChatCompletionMessage[] messages)
    {
        return cache.Keys.FirstOrDefault(k => k.SequenceEqual(messages));
    }

    // Deliberately not named TryGetValue: an extension method can't take precedence
    // over Dictionary's instance TryGetValue, which would silently fall back to
    // reference equality for the array keys and never find a match.
    public static bool TryGetCacheValue(
        this Dictionary<ILanguageModelChatCompletionMessage[], OllamaLanguageModelChatCompletionResponse> cache,
        ILanguageModelChatCompletionMessage[] messages, out OllamaLanguageModelChatCompletionResponse? value)
    {
        var key = cache.GetKey(messages);
        if (key is null)
        {
            value = null;
            return false;
        }

        value = cache[key];
        return true;
    }
}
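
// A minimal usage sketch (illustrative only; assumes a LanguageModelConfiguration
// and an ILogger are already wired up, e.g. through the host's DI container):
//
//   var client = new OllamaLanguageModelClient(config, logger);
//   if (await client.IsEnabled())
//   {
//       var completion = await client.GenerateCompletion("Reply with OK.");
//       Console.WriteLine(completion?.Response);
//   }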