22using System . IO ;
33using System . Text ;
44using System . Collections . Generic ;
5- using Tensorflow ;
65using static Tensorflow . KerasApi ;
76using Tensorflow . Keras . Engine ;
87using Tensorflow . NumPy ;
1615using Microsoft . Extensions . DependencyInjection ;
1716using System . Linq ;
1817using Tensorflow . Keras ;
19- using System . Numerics ;
20- using Newtonsoft . Json ;
21- using Tensorflow . Keras . Layers ;
2218using BotSharp . Abstraction . Agents ;
23- using BotSharp . Abstraction . Knowledges ;
2419
2520namespace BotSharp . Plugin . RoutingSpeeder . Providers ;
2621
@@ -33,11 +28,8 @@ public class IntentClassifier
3328 private bool _isModelReady ;
3429 public bool isModelReady => _isModelReady ;
3530 private ClassifierSetting _settings ;
36-
3731 private string [ ] _labels ;
38-
3932 public string [ ] Labels => GetLabels ( ) ;
40-
4133 private int _numLabels
4234 {
4335 get
@@ -67,7 +59,7 @@ private void Build()
6759 }
6860
6961 var vector = _services . GetServices < ITextEmbedding > ( )
70- . FirstOrDefault ( x => x . GetType ( ) . FullName . EndsWith ( _knowledgeBaseSettings . TextEmbedding ) ) ;
62+ . FirstOrDefault ( x => x . GetType ( ) . FullName . EndsWith ( _knowledgeBaseSettings . TextEmbedding ) ) ;
7163
7264 var layers = new List < ILayer >
7365 {
@@ -89,28 +81,29 @@ private void Fit(NDArray x, NDArray y, TrainingParams trainingParams)
8981 {
9082 _model . compile ( optimizer : keras . optimizers . Adam ( trainingParams . LearningRate ) ,
9183 loss : keras . losses . SparseCategoricalCrossentropy ( ) ,
92- metrics : new [ ] { "accuracy" }
93- ) ;
84+ metrics : new [ ] { "accuracy" } ) ;
9485
95- CallbackParams callback_parameters = new CallbackParams
86+ var callback_parameters = new CallbackParams
9687 {
9788 Model = _model ,
9889 Epochs = trainingParams . Epochs ,
9990 Verbose = 1 ,
10091 Steps = 10
10192 } ;
10293
103- ICallback earlyStop = new EarlyStopping ( callback_parameters , "accuracy" ) ;
94+ var earlyStop = new EarlyStopping ( callback_parameters , "accuracy" ) ;
10495
105- var callbacks = new List < ICallback > ( ) { earlyStop } ;
96+ var callbacks = new List < ICallback > ( )
97+ {
98+ earlyStop
99+ } ;
106100
107101 var weights = LoadWeights ( trainingParams . Inference ) ;
108102
109103 _model . fit ( x , y ,
110104 batch_size : trainingParams . BatchSize ,
111105 epochs : trainingParams . Epochs ,
112106 callbacks : callbacks ,
113- // validation_split: 0.1f,
114107 shuffle : true ) ;
115108
116109 _model . save_weights ( weights ) ;
@@ -120,7 +113,9 @@ private void Fit(NDArray x, NDArray y, TrainingParams trainingParams)
120113
121114 public string LoadWeights ( bool inference = true )
122115 {
123- var agentService = _services . CreateScope ( ) . ServiceProvider . GetRequiredService < IAgentService > ( ) ;
116+ var agentService = _services . CreateScope ( )
117+ . ServiceProvider
118+ . GetRequiredService < IAgentService > ( ) ;
124119
125120 var weightsFile = Path . Combine ( agentService . GetDataDir ( ) , _settings . MODEL_DIR , $ "intent-classifier.h5") ;
126121
@@ -129,13 +124,13 @@ public string LoadWeights(bool inference = true)
129124 _model . load_weights ( weightsFile ) ;
130125 _isModelReady = true ;
131126 Console . WriteLine ( $ "Successfully load the weights!") ;
132-
133127 }
134128 else
135129 {
136130 var logInfo = inference ? "No available weights." : "Will implement model training process and write trained weights into local" ;
137131 Console . WriteLine ( logInfo ) ;
138132 }
133+
139134 return weightsFile ;
140135 }
141136
@@ -152,24 +147,33 @@ public NDArray GetTextEmbedding(string text)
152147
153148 public ( NDArray , NDArray ) PrepareLoadData ( )
154149 {
155- var agentService = _services . CreateScope ( ) . ServiceProvider . GetRequiredService < IAgentService > ( ) ;
156- string rootDirectory = Path . Combine ( agentService . GetDataDir ( ) , _settings . RAW_DATA_DIR ) ;
157- string saveLabelDirectory = Path . Combine ( agentService . GetDataDir ( ) , _settings . MODEL_DIR , _settings . LABEL_FILE_NAME ) ;
150+ var agentService = _services . CreateScope ( )
151+ . ServiceProvider
152+ . GetRequiredService < IAgentService > ( ) ;
153+ string rootDirectory = Path . Combine (
154+ agentService . GetDataDir ( ) ,
155+ _settings . RAW_DATA_DIR ) ;
156+ string saveLabelDirectory = Path . Combine (
157+ agentService . GetDataDir ( ) ,
158+ _settings . MODEL_DIR ,
159+ _settings . LABEL_FILE_NAME ) ;
158160
159161 if ( ! Directory . Exists ( rootDirectory ) )
160162 {
161163 throw new Exception ( $ "No training data found! Please put training data in this path: { rootDirectory } ") ;
162164 }
163165
166+ // Do embedding and store results
164167 var vector = _services . GetRequiredService < ITextEmbedding > ( ) ;
165-
166168 var vectorList = new List < float [ ] > ( ) ;
167-
168169 var labelList = new List < string > ( ) ;
169170
170171 foreach ( var filePath in GetFiles ( ) )
171172 {
172- var texts = File . ReadAllLines ( filePath , Encoding . UTF8 ) . Select ( x => TextClean ( x ) ) . ToList ( ) ;
173+ var texts = File . ReadAllLines ( filePath , Encoding . UTF8 )
174+ . Select ( x => TextClean ( x ) )
175+ . ToList ( ) ;
176+
173177 vectorList . AddRange ( vector . GetVectors ( texts ) ) ;
174178 string fileName = Path . GetFileNameWithoutExtension ( filePath ) ;
175179 labelList . AddRange ( Enumerable . Repeat ( fileName , texts . Count ) . ToList ( ) ) ;
@@ -185,25 +189,39 @@ public NDArray GetTextEmbedding(string text)
185189 for ( int i = 0 ; i < vectorList . Count ; i ++ )
186190 {
187191 x [ i ] = vectorList [ i ] ;
188- // y[i] = (float)uniqueLabelList.IndexOf(labelList[i]);
189192 y [ i ] = ( float ) Array . IndexOf ( uniqueLabelList , labelList [ i ] ) ;
190193 }
194+
191195 return ( x , y ) ;
192196 }
193197
194198 public string [ ] GetFiles ( string prefix = "intent" )
195199 {
196- var agentService = _services . CreateScope ( ) . ServiceProvider . GetRequiredService < IAgentService > ( ) ;
200+ var agentService = _services . CreateScope ( )
201+ . ServiceProvider
202+ . GetRequiredService < IAgentService > ( ) ;
197203 string rootDirectory = Path . Combine ( agentService . GetDataDir ( ) , _settings . RAW_DATA_DIR ) ;
198- return Directory . GetFiles ( rootDirectory ) . Where ( x => Path . GetFileNameWithoutExtension ( x ) . StartsWith ( prefix ) ) . OrderBy ( x => x ) . ToArray ( ) ;
204+
205+ return Directory . GetFiles ( rootDirectory )
206+ . Where ( x => Path . GetFileNameWithoutExtension ( x )
207+ . StartsWith ( prefix ) )
208+ . OrderBy ( x => x )
209+ . ToArray ( ) ;
199210 }
200211
201212 public string [ ] GetLabels ( )
202213 {
203214 if ( _labels == null )
204215 {
205- var agentService = _services . CreateScope ( ) . ServiceProvider . GetRequiredService < IAgentService > ( ) ;
206- string rootDirectory = Path . Combine ( agentService . GetDataDir ( ) , _settings . MODEL_DIR , _settings . LABEL_FILE_NAME ) ;
216+ var agentService = _services . CreateScope ( )
217+ . ServiceProvider
218+ . GetRequiredService < IAgentService > ( ) ;
219+ string rootDirectory = Path . Combine (
220+ agentService . GetDataDir ( ) ,
221+ _settings . MODEL_DIR ,
222+ _settings . LABEL_FILE_NAME
223+ ) ;
224+
207225 var labelText = File . ReadAllLines ( rootDirectory ) ;
208226 _labels = labelText . OrderBy ( x => x ) . ToArray ( ) ;
209227 }
@@ -217,9 +235,11 @@ public string TextClean(string text)
217235 // Remove digits
218236 // To lowercase
219237 var processedText = Regex . Replace ( text , "[AB0-9]" , " " ) ;
220- processedText = string . Join ( "" , processedText . Select ( c => char . IsPunctuation ( c ) ? ' ' : c ) . ToList ( ) ) ;
221- processedText = processedText . Replace ( " " , " " ) . ToLower ( ) ;
222- return processedText ;
238+ var replacedTextList = processedText . Select ( c => char . IsPunctuation ( c ) ? ' ' : c ) . ToList ( ) ;
239+
240+ return string . Join ( "" , replacedTextList )
241+ . Replace ( " " , " " )
242+ . ToLower ( ) ;
223243 }
224244
225245 public string Predict ( NDArray vector , float confidenceScore = 0.9f )
@@ -229,8 +249,8 @@ public string Predict(NDArray vector, float confidenceScore = 0.9f)
229249 InitClassifer ( ) ;
230250 }
231251
252+ // Generate and post-process prediction
232253 var prob = _model . predict ( vector ) . numpy ( ) ;
233-
234254 var probLabel = tf . arg_max ( prob , - 1 ) . numpy ( ) . ToArray < long > ( ) ;
235255 prob = np . squeeze ( prob , axis : 0 ) ;
236256
@@ -239,9 +259,9 @@ public string Predict(NDArray vector, float confidenceScore = 0.9f)
239259 return string . Empty ;
240260 }
241261
242- var prediction = _labels [ probLabel [ 0 ] ] ;
262+ var labelIndex = probLabel [ 0 ] ;
243263
244- return prediction ;
264+ return _labels [ labelIndex ] ;
245265 }
246266 public void InitClassifer ( bool inference = true )
247267 {
0 commit comments