@@ -30,6 +30,8 @@ This code implements the Convolutional Tsetlin Machine from paper arXiv:1905.096
3030# define __builtin_popcount __popcnt
3131#endif
3232
33+ #define MAX_PATCH_MATCH 10
34+
3335#include <stdio.h>
3436#include <stdlib.h>
3537#include <limits.h>
@@ -106,26 +108,43 @@ static inline void cb_dec(
106108}
107109
108110/* Calculate the output of each clause using the actions of each Tsetline Automaton. */
109- static inline void cb_calculate_clause_output_feedback (unsigned int * ta_state , unsigned int * output_one_patches , unsigned int * clause_output , unsigned int * clause_patch , int number_of_ta_chunks , int number_of_state_bits , unsigned int filter , int number_of_patches , unsigned int * literal_active , unsigned int * Xi )
111+ static inline void cb_calculate_clause_output_feedback (
112+ unsigned int * ta_state ,
113+ unsigned int * output_one_patches ,
114+ unsigned int * clause_output ,
115+ unsigned int * clause_patch ,
116+ int number_of_ta_chunks ,
117+ int number_of_state_bits ,
118+ unsigned int filter ,
119+ int number_of_patches ,
120+ unsigned int * literal_active ,
121+ unsigned int * patch_match_count ,
122+ unsigned int * Xi
123+ )
110124{
111125 int output_one_patches_count = 0 ;
112126 for (int patch = 0 ; patch < number_of_patches ; ++ patch ) {
113- unsigned int output = 1 ;
127+ if (patch_match_count [patch ] > MAX_PATCH_MATCH ) {
128+ continue ;
129+ }
130+
131+ unsigned int patch_output = 1 ;
114132 for (int k = 0 ; k < number_of_ta_chunks - 1 ; k ++ ) {
115133 unsigned int pos = k * number_of_state_bits + number_of_state_bits - 1 ;
116- output = output && (ta_state [pos ] & (Xi [patch * number_of_ta_chunks + k ] | (~literal_active [k ]))) == ta_state [pos ];
134+ patch_output = patch_output && (ta_state [pos ] & (Xi [patch * number_of_ta_chunks + k ] | (~literal_active [k ]))) == ta_state [pos ];
117135
118- if (!output ) {
136+ if (!patch_output ) {
119137 break ;
120138 }
121139 }
122140
123141 unsigned int pos = (number_of_ta_chunks - 1 )* number_of_state_bits + number_of_state_bits - 1 ;
124- output = output &&
142+ patch_output = patch_output &&
125143 (ta_state [pos ] & (Xi [patch * number_of_ta_chunks + number_of_ta_chunks - 1 ] | (~literal_active [number_of_ta_chunks - 1 ])) & filter ) ==
126144 (ta_state [pos ] & filter );
127145
128- if (output ) {
146+ if (patch_output ) {
147+ patch_match_count [patch ]++ ;
129148 output_one_patches [output_one_patches_count ] = patch ;
130149 output_one_patches_count ++ ;
131150 }
@@ -196,30 +215,45 @@ static inline int cb_calculate_clause_output_single_false_literal(unsigned int *
196215 }
197216}
198217
199- static inline unsigned int cb_calculate_clause_output_update (unsigned int * ta_state , int number_of_ta_chunks , int number_of_state_bits , unsigned int filter , int number_of_patches , unsigned int * literal_active , unsigned int * Xi )
218+ static inline unsigned int cb_calculate_clause_output_update (
219+ unsigned int * ta_state ,
220+ int number_of_ta_chunks ,
221+ int number_of_state_bits ,
222+ unsigned int filter ,
223+ int number_of_patches ,
224+ unsigned int * literal_active ,
225+ unsigned int * patch_match_count ,
226+ unsigned int * Xi
227+ )
200228{
229+ unsigned int clause_output = 0 ;
201230 for (int patch = 0 ; patch < number_of_patches ; ++ patch ) {
202- unsigned int output = 1 ;
231+ if (patch_match_count [patch ] > MAX_PATCH_MATCH ) {
232+ continue ;
233+ }
234+
235+ unsigned int patch_output = 1 ;
203236 for (int k = 0 ; k < number_of_ta_chunks - 1 ; k ++ ) {
204237 unsigned int pos = k * number_of_state_bits + number_of_state_bits - 1 ;
205- output = output && (ta_state [pos ] & (Xi [patch * number_of_ta_chunks + k ] | (~literal_active [k ]))) == ta_state [pos ];
238+ patch_output = patch_output && (ta_state [pos ] & (Xi [patch * number_of_ta_chunks + k ] | (~literal_active [k ]))) == ta_state [pos ];
206239
207- if (!output ) {
240+ if (!patch_output ) {
208241 break ;
209242 }
210243 }
211244
212245 unsigned int pos = (number_of_ta_chunks - 1 )* number_of_state_bits + number_of_state_bits - 1 ;
213- output = output &&
246+ patch_output = patch_output &&
214247 (ta_state [pos ] & (Xi [patch * number_of_ta_chunks + number_of_ta_chunks - 1 ] | (~literal_active [number_of_ta_chunks - 1 ])) & filter ) ==
215248 (ta_state [pos ] & filter );
216249
217- if (output ) {
218- return (1 );
250+ if (patch_output ) {
251+ patch_match_count [patch ]++ ;
252+ clause_output = 1 ;
219253 }
220254 }
221255
222- return (0 );
256+ return (clause_output );
223257}
224258
225259static inline void cb_calculate_clause_output_patchwise (unsigned int * ta_state , int number_of_ta_chunks , int number_of_state_bits , unsigned int filter , int number_of_patches , unsigned int * output , unsigned int * Xi )
@@ -250,38 +284,45 @@ static inline unsigned int cb_calculate_clause_output_predict(
250284 int number_of_state_bits ,
251285 unsigned int filter ,
252286 int number_of_patches ,
287+ unsigned int * patch_match_count ,
253288 unsigned int * Xi
254289)
255290{
291+ unsigned int clause_output = 0 ;
292+
256293 for (int patch = 0 ; patch < number_of_patches ; ++ patch ) {
257- unsigned int output = 1 ;
294+ if (patch_match_count [patch ] > MAX_PATCH_MATCH ) {
295+ continue ;
296+ }
297+
298+ unsigned int patch_output = 1 ;
258299 unsigned int all_exclude = 1 ;
259300 for (int k = 0 ; k < number_of_ta_chunks - 1 ; k ++ ) {
260301 unsigned int pos = k * number_of_state_bits + number_of_state_bits - 1 ;
261- output = output && (ta_state [pos ] & Xi [patch * number_of_ta_chunks + k ]) == ta_state [pos ];
302+ patch_output = patch_output && (ta_state [pos ] & Xi [patch * number_of_ta_chunks + k ]) == ta_state [pos ];
262303
263- if (!output ) {
304+ if (!patch_output ) {
264305 break ;
265306 }
266307 all_exclude = all_exclude && (ta_state [pos ] == 0 );
267308 }
268309
269310 unsigned int pos = (number_of_ta_chunks - 1 )* number_of_state_bits + number_of_state_bits - 1 ;
270- output = output &&
311+ patch_output = patch_output &&
271312 (ta_state [pos ] & Xi [patch * number_of_ta_chunks + number_of_ta_chunks - 1 ] & filter ) ==
272313 (ta_state [pos ] & filter );
273314
274315 all_exclude = all_exclude && ((ta_state [pos ] & filter ) == 0 );
275316
276- if (output && all_exclude == 0 ) {
277- return (1 );
317+ if (patch_output && all_exclude == 0 ) {
318+ clause_output = 1 ;
319+ patch_match_count [patch ]++ ;
278320 }
279321 }
280322
281- return (0 );
323+ return (clause_output );
282324}
283325
284-
285326void cb_type_i_feedback (
286327 unsigned int * ta_state ,
287328 unsigned int * feedback_to_ta ,
@@ -297,6 +338,7 @@ void cb_type_i_feedback(
297338 unsigned int max_included_literals ,
298339 unsigned int * clause_active ,
299340 unsigned int * literal_active ,
341+ unsigned int * patch_match_count ,
300342 unsigned int * Xi
301343)
302344{
@@ -313,6 +355,10 @@ void cb_type_i_feedback(
313355 cb_initialize_random_streams (feedback_to_ta , number_of_literals , number_of_ta_chunks , s );
314356 }
315357
358+ for (int patch = 0 ; patch < number_of_patches ; ++ patch ) {
359+ patch_match_count [patch ] = 0 ;
360+ }
361+
316362 for (int j = 0 ; j < number_of_clauses ; ++ j ) {
317363 if ((((float )fast_rand ())/((float )FAST_RAND_MAX ) > update_p ) || (!clause_active [j ])) {
318364 continue ;
@@ -323,7 +369,7 @@ void cb_type_i_feedback(
323369 unsigned int clause_output ;
324370 unsigned int clause_patch ;
325371
326- cb_calculate_clause_output_feedback (& ta_state [clause_pos ], output_one_patches , & clause_output , & clause_patch , number_of_ta_chunks , number_of_state_bits , filter , number_of_patches , literal_active , Xi );
372+ cb_calculate_clause_output_feedback (& ta_state [clause_pos ], output_one_patches , & clause_output , & clause_patch , number_of_ta_chunks , number_of_state_bits , filter , number_of_patches , literal_active , patch_match_count , Xi );
327373
328374 if (!reuse_random_feedback && s > 1.0 ) {
329375 cb_initialize_random_streams (feedback_to_ta , number_of_literals , number_of_ta_chunks , s );
@@ -372,6 +418,7 @@ void cb_type_ii_feedback(
372418 float update_p ,
373419 unsigned int * clause_active ,
374420 unsigned int * literal_active ,
421+ unsigned int * patch_match_count ,
375422 unsigned int * Xi
376423)
377424{
@@ -383,6 +430,10 @@ void cb_type_ii_feedback(
383430 }
384431 unsigned int number_of_ta_chunks = (number_of_literals - 1 )/32 + 1 ;
385432
433+ for (int patch = 0 ; patch < number_of_patches ; ++ patch ) {
434+ patch_match_count [patch ] = 0 ;
435+ }
436+
386437 for (int j = 0 ; j < number_of_clauses ; j ++ ) {
387438 if ((((float )fast_rand ())/((float )FAST_RAND_MAX ) > update_p ) || (!clause_active [j ])) {
388439 continue ;
@@ -392,7 +443,7 @@ void cb_type_ii_feedback(
392443
393444 unsigned int clause_output ;
394445 unsigned int clause_patch ;
395- cb_calculate_clause_output_feedback (& ta_state [clause_pos ], output_one_patches , & clause_output , & clause_patch , number_of_ta_chunks , number_of_state_bits , filter , number_of_patches , literal_active , Xi );
446+ cb_calculate_clause_output_feedback (& ta_state [clause_pos ], output_one_patches , & clause_output , & clause_patch , number_of_ta_chunks , number_of_state_bits , filter , number_of_patches , literal_active , patch_match_count , Xi );
396447
397448 if (clause_output ) {
398449 for (int k = 0 ; k < number_of_ta_chunks ; ++ k ) {
@@ -449,6 +500,7 @@ void cb_type_iii_feedback(
449500 filter ,
450501 number_of_patches ,
451502 literal_active ,
503+ NULL ,
452504 Xi
453505 );
454506
@@ -543,9 +595,13 @@ void cb_calculate_clause_outputs_predict(
543595 }
544596 unsigned int number_of_ta_chunks = (number_of_literals - 1 )/32 + 1 ;
545597
598+ for (int patch = 0 ; patch < number_of_patches ; ++ patch ) {
599+ patch_match_count [patch ] = 0 ;
600+ }
601+
546602 for (int j = 0 ; j < number_of_clauses ; j ++ ) {
547603 unsigned int clause_pos = j * number_of_ta_chunks * number_of_state_bits ;
548- clause_output [j ] = cb_calculate_clause_output_predict (& ta_state [clause_pos ], number_of_ta_chunks , number_of_state_bits , filter , number_of_patches , Xi );
604+ clause_output [j ] = cb_calculate_clause_output_predict (& ta_state [clause_pos ], number_of_ta_chunks , number_of_state_bits , filter , number_of_patches , patch_match_count , Xi );
549605 }
550606}
551607
@@ -741,6 +797,7 @@ void cb_calculate_clause_outputs_update(
741797 int number_of_patches ,
742798 unsigned int * clause_output ,
743799 unsigned int * literal_active ,
800+ unsigned int * patch_match_count ,
744801 unsigned int * Xi
745802)
746803{
@@ -753,9 +810,13 @@ void cb_calculate_clause_outputs_update(
753810
754811 unsigned int number_of_ta_chunks = (number_of_literals - 1 )/32 + 1 ;
755812
813+ for (int patch = 0 ; patch < number_of_patches ; ++ patch ) {
814+ patch_match_count [patch ] = 0 ;
815+ }
816+
756817 for (int j = 0 ; j < number_of_clauses ; j ++ ) {
757818 unsigned int clause_pos = j * number_of_ta_chunks * number_of_state_bits ;
758- clause_output [j ] = cb_calculate_clause_output_update (& ta_state [clause_pos ], number_of_ta_chunks , number_of_state_bits , filter , number_of_patches , literal_active , Xi );
819+ clause_output [j ] = cb_calculate_clause_output_update (& ta_state [clause_pos ], number_of_ta_chunks , number_of_state_bits , filter , number_of_patches , literal_active , patch_match_count , Xi );
759820 }
760821}
761822
0 commit comments