@@ -33,10 +33,8 @@ public class IntentClassifier
3333 private bool _isModelReady ;
3434 public bool isModelReady => _isModelReady ;
3535 private ClassifierSetting _settings ;
36-
3736 private string [ ] _labels ;
3837 public string [ ] Labels => GetLabels ( ) ;
39-
4038 private int _numLabels
4139 {
4240 get
@@ -111,7 +109,6 @@ private void Fit(NDArray x, NDArray y, TrainingParams trainingParams)
111109 batch_size : trainingParams . BatchSize ,
112110 epochs : trainingParams . Epochs ,
113111 callbacks : callbacks ,
114- // validation_split: 0.1f,
115112 shuffle : true ) ;
116113
117114 _model . save_weights ( weights ) ;
@@ -155,9 +152,18 @@ public NDArray GetTextEmbedding(string text)
155152
156153 public ( NDArray , NDArray ) PrepareLoadData ( )
157154 {
158- var agentService = _services . CreateScope ( ) . ServiceProvider . GetRequiredService < IAgentService > ( ) ;
159- string rootDirectory = Path . Combine ( agentService . GetDataDir ( ) , _settings . RAW_DATA_DIR ) ;
160- string saveLabelDirectory = Path . Combine ( agentService . GetDataDir ( ) , _settings . MODEL_DIR , _settings . LABEL_FILE_NAME ) ;
155+ var agentService = _services . CreateScope ( )
156+ . ServiceProvider
157+ . GetRequiredService < IAgentService > ( ) ;
158+ string rootDirectory = Path . Combine (
159+ agentService . GetDataDir ( ) ,
160+ _settings . RAW_DATA_DIR
161+ ) ;
162+ string saveLabelDirectory = Path . Combine (
163+ agentService . GetDataDir ( ) ,
164+ _settings . MODEL_DIR ,
165+ _settings . LABEL_FILE_NAME
166+ ) ;
161167
162168 if ( ! Directory . Exists ( rootDirectory ) )
163169 {
@@ -171,7 +177,10 @@ public NDArray GetTextEmbedding(string text)
171177
172178 foreach ( var filePath in GetFiles ( ) )
173179 {
174- var texts = File . ReadAllLines ( filePath , Encoding . UTF8 ) . Select ( x => TextClean ( x ) ) . ToList ( ) ;
180+ var texts = File . ReadAllLines ( filePath , Encoding . UTF8 )
181+ . Select ( x => TextClean ( x ) )
182+ . ToList ( ) ;
183+
175184 vectorList . AddRange ( vector . GetVectors ( texts ) ) ;
176185 string fileName = Path . GetFileNameWithoutExtension ( filePath ) ;
177186 labelList . AddRange ( Enumerable . Repeat ( fileName , texts . Count ) . ToList ( ) ) ;
@@ -187,16 +196,19 @@ public NDArray GetTextEmbedding(string text)
187196 for ( int i = 0 ; i < vectorList . Count ; i ++ )
188197 {
189198 x [ i ] = vectorList [ i ] ;
190- // y[i] = (float)uniqueLabelList.IndexOf(labelList[i]);
191199 y [ i ] = ( float ) Array . IndexOf ( uniqueLabelList , labelList [ i ] ) ;
192200 }
201+
193202 return ( x , y ) ;
194203 }
195204
196205 public string [ ] GetFiles ( string prefix = "intent" )
197206 {
198- var agentService = _services . CreateScope ( ) . ServiceProvider . GetRequiredService < IAgentService > ( ) ;
207+ var agentService = _services . CreateScope ( )
208+ . ServiceProvider
209+ . GetRequiredService < IAgentService > ( ) ;
199210 string rootDirectory = Path . Combine ( agentService . GetDataDir ( ) , _settings . RAW_DATA_DIR ) ;
211+
200212 return Directory . GetFiles ( rootDirectory )
201213 . Where ( x => Path . GetFileNameWithoutExtension ( x )
202214 . StartsWith ( prefix ) )
@@ -208,8 +220,15 @@ public string[] GetLabels()
208220 {
209221 if ( _labels == null )
210222 {
211- var agentService = _services . CreateScope ( ) . ServiceProvider . GetRequiredService < IAgentService > ( ) ;
212- string rootDirectory = Path . Combine ( agentService . GetDataDir ( ) , _settings . MODEL_DIR , _settings . LABEL_FILE_NAME ) ;
223+ var agentService = _services . CreateScope ( )
224+ . ServiceProvider
225+ . GetRequiredService < IAgentService > ( ) ;
226+ string rootDirectory = Path . Combine (
227+ agentService . GetDataDir ( ) ,
228+ _settings . MODEL_DIR ,
229+ _settings . LABEL_FILE_NAME
230+ ) ;
231+
213232 var labelText = File . ReadAllLines ( rootDirectory ) ;
214233 _labels = labelText . OrderBy ( x => x ) . ToArray ( ) ;
215234 }
@@ -223,9 +242,11 @@ public string TextClean(string text)
223242 // Remove digits
224243 // To lowercase
225244 var processedText = Regex . Replace ( text , "[AB0-9]" , " " ) ;
226- processedText = string . Join ( "" , processedText . Select ( c => char . IsPunctuation ( c ) ? ' ' : c ) . ToList ( ) ) ;
227- processedText = processedText . Replace ( " " , " " ) . ToLower ( ) ;
228- return processedText ;
245+ var replacedTextList = processedText . Select ( c => char . IsPunctuation ( c ) ? ' ' : c ) . ToList ( ) ;
246+
247+ return string . Join ( "" , replacedTextList )
248+ . Replace ( " " , " " )
249+ . ToLower ( ) ;
229250 }
230251
231252 public string Predict ( NDArray vector , float confidenceScore = 0.9f )
@@ -235,8 +256,8 @@ public string Predict(NDArray vector, float confidenceScore = 0.9f)
235256 InitClassifer ( ) ;
236257 }
237258
259+ // Generate and post-process prediction
238260 var prob = _model . predict ( vector ) . numpy ( ) ;
239-
240261 var probLabel = tf . arg_max ( prob , - 1 ) . numpy ( ) . ToArray < long > ( ) ;
241262 prob = np . squeeze ( prob , axis : 0 ) ;
242263
@@ -245,9 +266,9 @@ public string Predict(NDArray vector, float confidenceScore = 0.9f)
245266 return string . Empty ;
246267 }
247268
248- var prediction = _labels [ probLabel [ 0 ] ] ;
269+ var labelIndex = probLabel [ 0 ] ;
249270
250- return prediction ;
271+ return _labels [ labelIndex ] ;
251272 }
252273 public void InitClassifer ( bool inference = true )
253274 {
0 commit comments