Skip to content

Commit 0e77c03

Browse files
committed
Update
1 parent 6fba1dd commit 0e77c03

File tree

3 files changed

+94
-27
lines changed

3 files changed

+94
-27
lines changed

tmu/clause_bank/clause_bank.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -157,7 +157,7 @@ def initialize_clauses(self):
157157
def calculate_clause_outputs_predict(self, encoded_X, e):
158158
xi_p = ffi.cast("unsigned int *", encoded_X[e, :].ctypes.data)
159159

160-
if not self.incremental:
160+
if True or not self.incremental:
161161
lib.cb_calculate_clause_outputs_predict(
162162
self.ptr_ta_state,
163163
self.number_of_clauses,
@@ -215,6 +215,7 @@ def calculate_clause_outputs_update(self, literal_active, encoded_X, e):
215215
self.number_of_patches,
216216
self.co_p,
217217
la_p,
218+
self.ptr_patch_match_count,
218219
xi_p
219220
)
220221

@@ -261,6 +262,7 @@ def type_i_feedback(
261262
self.max_included_literals,
262263
ptr_clause_active,
263264
ptr_literal_active,
265+
self.ptr_patch_match_count,
264266
ptr_xi
265267
)
266268

@@ -288,6 +290,7 @@ def type_ii_feedback(
288290
update_p,
289291
ptr_clause_active,
290292
ptr_literal_active,
293+
self.ptr_patch_match_count,
291294
ptr_xi
292295
)
293296

tmu/lib/include/ClauseBank.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ void cb_type_i_feedback(
4040
unsigned int max_included_literals,
4141
unsigned int *clause_active,
4242
unsigned int *literal_active,
43+
unsigned int *patch_match_count,
4344
unsigned int *Xi
4445
);
4546

@@ -53,6 +54,7 @@ void cb_type_ii_feedback(
5354
float update_p,
5455
unsigned int *clause_active,
5556
unsigned int *literal_active,
57+
unsigned int *patch_match_count,
5658
unsigned int *Xi
5759
);
5860

@@ -93,6 +95,7 @@ void cb_calculate_clause_outputs_update(
9395
int number_of_patches,
9496
unsigned int *clause_output,
9597
unsigned int *literal_active,
98+
unsigned int *patch_match_count,
9699
unsigned int *Xi
97100
);
98101

tmu/lib/src/ClauseBank.c

Lines changed: 87 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,8 @@ This code implements the Convolutional Tsetlin Machine from paper arXiv:1905.096
3030
# define __builtin_popcount __popcnt
3131
#endif
3232

33+
#define MAX_PATCH_MATCH 10
34+
3335
#include <stdio.h>
3436
#include <stdlib.h>
3537
#include <limits.h>
@@ -106,26 +108,43 @@ static inline void cb_dec(
106108
}
107109

108110
/* Calculate the output of each clause using the actions of each Tsetline Automaton. */
109-
static inline void cb_calculate_clause_output_feedback(unsigned int *ta_state, unsigned int *output_one_patches, unsigned int *clause_output, unsigned int *clause_patch, int number_of_ta_chunks, int number_of_state_bits, unsigned int filter, int number_of_patches, unsigned int *literal_active, unsigned int *Xi)
111+
static inline void cb_calculate_clause_output_feedback(
112+
unsigned int *ta_state,
113+
unsigned int *output_one_patches,
114+
unsigned int *clause_output,
115+
unsigned int *clause_patch,
116+
int number_of_ta_chunks,
117+
int number_of_state_bits,
118+
unsigned int filter,
119+
int number_of_patches,
120+
unsigned int *literal_active,
121+
unsigned int *patch_match_count,
122+
unsigned int *Xi
123+
)
110124
{
111125
int output_one_patches_count = 0;
112126
for (int patch = 0; patch < number_of_patches; ++patch) {
113-
unsigned int output = 1;
127+
if (patch_match_count[patch] > MAX_PATCH_MATCH) {
128+
continue;
129+
}
130+
131+
unsigned int patch_output = 1;
114132
for (int k = 0; k < number_of_ta_chunks-1; k++) {
115133
unsigned int pos = k*number_of_state_bits + number_of_state_bits-1;
116-
output = output && (ta_state[pos] & (Xi[patch*number_of_ta_chunks + k] | (~literal_active[k]))) == ta_state[pos];
134+
patch_output = patch_output && (ta_state[pos] & (Xi[patch*number_of_ta_chunks + k] | (~literal_active[k]))) == ta_state[pos];
117135

118-
if (!output) {
136+
if (!patch_output) {
119137
break;
120138
}
121139
}
122140

123141
unsigned int pos = (number_of_ta_chunks-1)*number_of_state_bits + number_of_state_bits-1;
124-
output = output &&
142+
patch_output = patch_output &&
125143
(ta_state[pos] & (Xi[patch*number_of_ta_chunks + number_of_ta_chunks - 1] | (~literal_active[number_of_ta_chunks - 1])) & filter) ==
126144
(ta_state[pos] & filter);
127145

128-
if (output) {
146+
if (patch_output) {
147+
patch_match_count[patch]++;
129148
output_one_patches[output_one_patches_count] = patch;
130149
output_one_patches_count++;
131150
}
@@ -196,30 +215,45 @@ static inline int cb_calculate_clause_output_single_false_literal(unsigned int *
196215
}
197216
}
198217

199-
static inline unsigned int cb_calculate_clause_output_update(unsigned int *ta_state, int number_of_ta_chunks, int number_of_state_bits, unsigned int filter, int number_of_patches, unsigned int *literal_active, unsigned int *Xi)
218+
static inline unsigned int cb_calculate_clause_output_update(
219+
unsigned int *ta_state,
220+
int number_of_ta_chunks,
221+
int number_of_state_bits,
222+
unsigned int filter,
223+
int number_of_patches,
224+
unsigned int *literal_active,
225+
unsigned int *patch_match_count,
226+
unsigned int *Xi
227+
)
200228
{
229+
unsigned int clause_output = 0;
201230
for (int patch = 0; patch < number_of_patches; ++patch) {
202-
unsigned int output = 1;
231+
if (patch_match_count[patch] > MAX_PATCH_MATCH) {
232+
continue;
233+
}
234+
235+
unsigned int patch_output = 1;
203236
for (int k = 0; k < number_of_ta_chunks-1; k++) {
204237
unsigned int pos = k*number_of_state_bits + number_of_state_bits-1;
205-
output = output && (ta_state[pos] & (Xi[patch*number_of_ta_chunks + k] | (~literal_active[k]))) == ta_state[pos];
238+
patch_output = patch_output && (ta_state[pos] & (Xi[patch*number_of_ta_chunks + k] | (~literal_active[k]))) == ta_state[pos];
206239

207-
if (!output) {
240+
if (!patch_output) {
208241
break;
209242
}
210243
}
211244

212245
unsigned int pos = (number_of_ta_chunks-1)*number_of_state_bits + number_of_state_bits-1;
213-
output = output &&
246+
patch_output = patch_output &&
214247
(ta_state[pos] & (Xi[patch*number_of_ta_chunks + number_of_ta_chunks - 1] | (~literal_active[number_of_ta_chunks - 1])) & filter) ==
215248
(ta_state[pos] & filter);
216249

217-
if (output) {
218-
return(1);
250+
if (patch_output) {
251+
patch_match_count[patch]++;
252+
clause_output = 1;
219253
}
220254
}
221255

222-
return(0);
256+
return(clause_output);
223257
}
224258

225259
static inline void cb_calculate_clause_output_patchwise(unsigned int *ta_state, int number_of_ta_chunks, int number_of_state_bits, unsigned int filter, int number_of_patches, unsigned int *output, unsigned int *Xi)
@@ -250,38 +284,45 @@ static inline unsigned int cb_calculate_clause_output_predict(
250284
int number_of_state_bits,
251285
unsigned int filter,
252286
int number_of_patches,
287+
unsigned int *patch_match_count,
253288
unsigned int *Xi
254289
)
255290
{
291+
unsigned int clause_output = 0;
292+
256293
for (int patch = 0; patch < number_of_patches; ++patch) {
257-
unsigned int output = 1;
294+
if (patch_match_count[patch] > MAX_PATCH_MATCH) {
295+
continue;
296+
}
297+
298+
unsigned int patch_output = 1;
258299
unsigned int all_exclude = 1;
259300
for (int k = 0; k < number_of_ta_chunks-1; k++) {
260301
unsigned int pos = k*number_of_state_bits + number_of_state_bits-1;
261-
output = output && (ta_state[pos] & Xi[patch*number_of_ta_chunks + k]) == ta_state[pos];
302+
patch_output = patch_output && (ta_state[pos] & Xi[patch*number_of_ta_chunks + k]) == ta_state[pos];
262303

263-
if (!output) {
304+
if (!patch_output) {
264305
break;
265306
}
266307
all_exclude = all_exclude && (ta_state[pos] == 0);
267308
}
268309

269310
unsigned int pos = (number_of_ta_chunks-1)*number_of_state_bits + number_of_state_bits-1;
270-
output = output &&
311+
patch_output = patch_output &&
271312
(ta_state[pos] & Xi[patch*number_of_ta_chunks + number_of_ta_chunks - 1] & filter) ==
272313
(ta_state[pos] & filter);
273314

274315
all_exclude = all_exclude && ((ta_state[pos] & filter) == 0);
275316

276-
if (output && all_exclude == 0) {
277-
return(1);
317+
if (patch_output && all_exclude == 0) {
318+
clause_output = 1;
319+
patch_match_count[patch]++;
278320
}
279321
}
280322

281-
return(0);
323+
return(clause_output);
282324
}
283325

284-
285326
void cb_type_i_feedback(
286327
unsigned int *ta_state,
287328
unsigned int *feedback_to_ta,
@@ -297,6 +338,7 @@ void cb_type_i_feedback(
297338
unsigned int max_included_literals,
298339
unsigned int *clause_active,
299340
unsigned int *literal_active,
341+
unsigned int *patch_match_count,
300342
unsigned int *Xi
301343
)
302344
{
@@ -313,6 +355,10 @@ void cb_type_i_feedback(
313355
cb_initialize_random_streams(feedback_to_ta, number_of_literals, number_of_ta_chunks, s);
314356
}
315357

358+
for (int patch = 0; patch < number_of_patches; ++patch) {
359+
patch_match_count[patch] = 0;
360+
}
361+
316362
for (int j = 0; j < number_of_clauses; ++j) {
317363
if ((((float)fast_rand())/((float)FAST_RAND_MAX) > update_p) || (!clause_active[j])) {
318364
continue;
@@ -323,7 +369,7 @@ void cb_type_i_feedback(
323369
unsigned int clause_output;
324370
unsigned int clause_patch;
325371

326-
cb_calculate_clause_output_feedback(&ta_state[clause_pos], output_one_patches, &clause_output, &clause_patch, number_of_ta_chunks, number_of_state_bits, filter, number_of_patches, literal_active, Xi);
372+
cb_calculate_clause_output_feedback(&ta_state[clause_pos], output_one_patches, &clause_output, &clause_patch, number_of_ta_chunks, number_of_state_bits, filter, number_of_patches, literal_active, patch_match_count, Xi);
327373

328374
if (!reuse_random_feedback && s > 1.0) {
329375
cb_initialize_random_streams(feedback_to_ta, number_of_literals, number_of_ta_chunks, s);
@@ -372,6 +418,7 @@ void cb_type_ii_feedback(
372418
float update_p,
373419
unsigned int *clause_active,
374420
unsigned int *literal_active,
421+
unsigned int *patch_match_count,
375422
unsigned int *Xi
376423
)
377424
{
@@ -383,6 +430,10 @@ void cb_type_ii_feedback(
383430
}
384431
unsigned int number_of_ta_chunks = (number_of_literals-1)/32 + 1;
385432

433+
for (int patch = 0; patch < number_of_patches; ++patch) {
434+
patch_match_count[patch] = 0;
435+
}
436+
386437
for (int j = 0; j < number_of_clauses; j++) {
387438
if ((((float)fast_rand())/((float)FAST_RAND_MAX) > update_p) || (!clause_active[j])) {
388439
continue;
@@ -392,7 +443,7 @@ void cb_type_ii_feedback(
392443

393444
unsigned int clause_output;
394445
unsigned int clause_patch;
395-
cb_calculate_clause_output_feedback(&ta_state[clause_pos], output_one_patches, &clause_output, &clause_patch, number_of_ta_chunks, number_of_state_bits, filter, number_of_patches, literal_active, Xi);
446+
cb_calculate_clause_output_feedback(&ta_state[clause_pos], output_one_patches, &clause_output, &clause_patch, number_of_ta_chunks, number_of_state_bits, filter, number_of_patches, literal_active, patch_match_count, Xi);
396447

397448
if (clause_output) {
398449
for (int k = 0; k < number_of_ta_chunks; ++k) {
@@ -449,6 +500,7 @@ void cb_type_iii_feedback(
449500
filter,
450501
number_of_patches,
451502
literal_active,
503+
NULL,
452504
Xi
453505
);
454506

@@ -543,9 +595,13 @@ void cb_calculate_clause_outputs_predict(
543595
}
544596
unsigned int number_of_ta_chunks = (number_of_literals-1)/32 + 1;
545597

598+
for (int patch = 0; patch < number_of_patches; ++patch) {
599+
patch_match_count[patch] = 0;
600+
}
601+
546602
for (int j = 0; j < number_of_clauses; j++) {
547603
unsigned int clause_pos = j*number_of_ta_chunks*number_of_state_bits;
548-
clause_output[j] = cb_calculate_clause_output_predict(&ta_state[clause_pos], number_of_ta_chunks, number_of_state_bits, filter, number_of_patches, Xi);
604+
clause_output[j] = cb_calculate_clause_output_predict(&ta_state[clause_pos], number_of_ta_chunks, number_of_state_bits, filter, number_of_patches, patch_match_count, Xi);
549605
}
550606
}
551607

@@ -741,6 +797,7 @@ void cb_calculate_clause_outputs_update(
741797
int number_of_patches,
742798
unsigned int *clause_output,
743799
unsigned int *literal_active,
800+
unsigned int *patch_match_count,
744801
unsigned int *Xi
745802
)
746803
{
@@ -753,9 +810,13 @@ void cb_calculate_clause_outputs_update(
753810

754811
unsigned int number_of_ta_chunks = (number_of_literals-1)/32 + 1;
755812

813+
for (int patch = 0; patch < number_of_patches; ++patch) {
814+
patch_match_count[patch] = 0;
815+
}
816+
756817
for (int j = 0; j < number_of_clauses; j++) {
757818
unsigned int clause_pos = j*number_of_ta_chunks*number_of_state_bits;
758-
clause_output[j] = cb_calculate_clause_output_update(&ta_state[clause_pos], number_of_ta_chunks, number_of_state_bits, filter, number_of_patches, literal_active, Xi);
819+
clause_output[j] = cb_calculate_clause_output_update(&ta_state[clause_pos], number_of_ta_chunks, number_of_state_bits, filter, number_of_patches, literal_active, patch_match_count, Xi);
759820
}
760821
}
761822

0 commit comments

Comments
 (0)