// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System.Diagnostics;
using System.Net.Http.Json;
using Microsoft.Extensions.Logging;

namespace DevProxy.Abstractions.LanguageModel;

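// Client for a local LM Studio server, which exposes OpenAI-compatible
// /v1/models, /v1/completions, and /v1/chat/completions endpoints.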
public class LMStudioLanguageModelClient(LanguageModelConfiguration? configuration, ILogger logger) : ILanguageModelClient
{
    private readonly LanguageModelConfiguration? _configuration = configuration;
    private readonly ILogger _logger = logger;
    private bool? _lmAvailable;
    private readonly Dictionary<string, OpenAICompletionResponse> _cacheCompletion = [];
    private readonly Dictionary<ILanguageModelChatCompletionMessage[], OpenAIChatCompletionResponse> _cacheChatCompletion = [];

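    // Returns whether the configured language model is reachable.
    // The underlying check runs only once; its result is cached in _lmAvailable.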
    public async Task<bool> IsEnabledAsync()
    {
        if (_lmAvailable.HasValue)
        {
            return _lmAvailable.Value;
        }

        _lmAvailable = await IsEnabledInternalAsync();
        return _lmAvailable.Value;
    }

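    // Validates the configuration, pings the server's /v1/models endpoint,
    // and sends a test prompt to confirm the model can answer.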
    private async Task<bool> IsEnabledInternalAsync()
    {
        if (_configuration is null || !_configuration.Enabled)
        {
            return false;
        }

        if (string.IsNullOrEmpty(_configuration.Url))
        {
            _logger.LogError("URL is not set. Language model will be disabled");
            return false;
        }

        if (string.IsNullOrEmpty(_configuration.Model))
        {
            _logger.LogError("Model is not set. Language model will be disabled");
            return false;
        }

        _logger.LogDebug("Checking LM availability at {url}...", _configuration.Url);

        try
        {
            // Check that the server is listening before sending a test prompt
            using var client = new HttpClient();
            var response = await client.GetAsync($"{_configuration.Url}/v1/models");
            _logger.LogDebug("Response: {response}", response.StatusCode);

            if (!response.IsSuccessStatusCode)
            {
                return false;
            }

            var testCompletion = await GenerateCompletionInternalAsync("Are you there? Reply with a yes or no.");
            if (testCompletion?.Error is not null)
            {
                _logger.LogError("Error: {error}. Param: {param}", testCompletion.Error.Message, testCompletion.Error.Param);
                return false;
            }

            return true;
        }
        catch (Exception ex)
        {
            _logger.LogError(ex, "Couldn't reach language model at {url}", _configuration.Url);
            return false;
        }
    }

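    // Generates a completion for the given prompt, returning a cached response
    // when response caching is enabled and the prompt has been seen before.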
    public async Task<ILanguageModelCompletionResponse?> GenerateCompletionAsync(string prompt, CompletionOptions? options = null)
    {
        using var scope = _logger.BeginScope(nameof(LMStudioLanguageModelClient));

        if (_configuration is null)
        {
            return null;
        }

        if (!_lmAvailable.HasValue)
        {
            _logger.LogError("Language model availability is not checked. Call {isEnabled} first.", nameof(IsEnabledAsync));
            return null;
        }

        if (!_lmAvailable.Value)
        {
            return null;
        }

        if (_configuration.CacheResponses && _cacheCompletion.TryGetValue(prompt, out var cachedResponse))
        {
            _logger.LogDebug("Returning cached response for prompt: {prompt}", prompt);
            return cachedResponse;
        }

        var response = await GenerateCompletionInternalAsync(prompt, options);
        if (response is null)
        {
            return null;
        }
        if (response.Error is not null)
        {
            _logger.LogError("Error: {error}. Param: {param}", response.Error.Message, response.Error.Param);
            return null;
        }

        if (_configuration.CacheResponses && response.Response is not null)
        {
            _cacheCompletion[prompt] = response;
        }

        return response;
    }

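    // Posts the prompt to /v1/completions and tags the response with the
    // request URL. Temperature falls back to 0.8 when no options are provided.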
    private async Task<OpenAICompletionResponse?> GenerateCompletionInternalAsync(string prompt, CompletionOptions? options = null)
    {
        Debug.Assert(_configuration != null, "Configuration is null");

        try
        {
            using var client = new HttpClient();
            var url = $"{_configuration.Url}/v1/completions";
            _logger.LogDebug("Requesting completion. Prompt: {prompt}", prompt);

            var response = await client.PostAsJsonAsync(url,
                new
                {
                    prompt,
                    model = _configuration.Model,
                    stream = false,
                    temperature = options?.Temperature ?? 0.8,
                }
            );
            _logger.LogDebug("Response: {response}", response.StatusCode);

            var res = await response.Content.ReadFromJsonAsync<OpenAICompletionResponse>();
            if (res is null)
            {
                return res;
            }

            res.RequestUrl = url;
            return res;
        }
        catch (Exception ex)
        {
            _logger.LogError(ex, "Failed to generate completion");
            return null;
        }
    }

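    // Generates a chat completion for the given conversation, returning a
    // cached response when caching is enabled and the same message sequence
    // has been seen before.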
    public async Task<ILanguageModelCompletionResponse?> GenerateChatCompletionAsync(ILanguageModelChatCompletionMessage[] messages)
    {
        using var scope = _logger.BeginScope(nameof(LMStudioLanguageModelClient));

        if (_configuration is null)
        {
            return null;
        }

        if (!_lmAvailable.HasValue)
        {
            _logger.LogError("Language model availability is not checked. Call {isEnabled} first.", nameof(IsEnabledAsync));
            return null;
        }

        if (!_lmAvailable.Value)
        {
            return null;
        }

        if (_configuration.CacheResponses && _cacheChatCompletion.TryGetValue(messages, out var cachedResponse))
        {
            _logger.LogDebug("Returning cached response for message: {lastMessage}", messages.Last().Content);
            return cachedResponse;
        }

        var response = await GenerateChatCompletionInternalAsync(messages);
        if (response is null)
        {
            return null;
        }
        if (response.Error is not null)
        {
            _logger.LogError("Error: {error}. Param: {param}", response.Error.Message, response.Error.Param);
            return null;
        }

        if (_configuration.CacheResponses && response.Response is not null)
        {
            _cacheChatCompletion[messages] = response;
        }

        return response;
    }

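    // Posts the conversation to /v1/chat/completions and tags the response
    // with the request URL.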
    private async Task<OpenAIChatCompletionResponse?> GenerateChatCompletionInternalAsync(ILanguageModelChatCompletionMessage[] messages)
    {
        Debug.Assert(_configuration != null, "Configuration is null");

        try
        {
            using var client = new HttpClient();
            var url = $"{_configuration.Url}/v1/chat/completions";
            _logger.LogDebug("Requesting chat completion. Message: {lastMessage}", messages.Last().Content);

            var response = await client.PostAsJsonAsync(url,
                new
                {
                    messages,
                    model = _configuration.Model,
                    stream = false
                }
            );
            _logger.LogDebug("Response: {response}", response.StatusCode);

            var res = await response.Content.ReadFromJsonAsync<OpenAIChatCompletionResponse>();
            if (res is null)
            {
                return res;
            }

            res.RequestUrl = url;
            return res;
        }
        catch (Exception ex)
        {
            _logger.LogError(ex, "Failed to generate chat completion");
            return null;
        }
    }
}

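// Helpers for cache lookups keyed by message arrays. Arrays compare by
// reference by default, so these compare the message sequences instead.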
internal static class CacheChatCompletionExtensions
{
    // Key type matches _cacheChatCompletion so these helpers apply to it;
    // GetKey finds the stored key whose contents equal the incoming messages.
    public static ILanguageModelChatCompletionMessage[]? GetKey(
        this Dictionary<ILanguageModelChatCompletionMessage[], OpenAIChatCompletionResponse> cache,
        ILanguageModelChatCompletionMessage[] messages)
    {
        return cache.Keys.FirstOrDefault(k => k.SequenceEqual(messages));
    }

    public static bool TryGetValue(
        this Dictionary<ILanguageModelChatCompletionMessage[], OpenAIChatCompletionResponse> cache,
        ILanguageModelChatCompletionMessage[] messages, out OpenAIChatCompletionResponse? value)
    {
        var key = cache.GetKey(messages);
        if (key is null)
        {
            value = null;
            return false;
        }

        value = cache[key];
        return true;
    }
}
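// Example usage (a sketch only; `config` and `logger` are assumed to come from
// the host application, and the response is assumed to expose the Response
// text that the caching code above checks):
//
//   ILanguageModelClient lm = new LMStudioLanguageModelClient(config, logger);
//   if (await lm.IsEnabledAsync())
//   {
//       var completion = await lm.GenerateCompletionAsync("Say hello");
//       Console.WriteLine(completion?.Response);
//   }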