diff --git a/MultiClassTsetlinMachine.c b/MultiClassTsetlinMachine.c index 2357312..15ca6ce 100644 --- a/MultiClassTsetlinMachine.c +++ b/MultiClassTsetlinMachine.c @@ -73,7 +73,7 @@ float mc_tm_evaluate(struct MultiClassTsetlinMachine *mc_tm, int X[][FEATURES], max_class_sum = tm_score(mc_tm->tsetlin_machines[0], X[l]); max_class = 0; - for (int i = 1; i < CLASSES; i++) { + for (int i = 1; i < CLASSES; i++) { int class_sum = tm_score(mc_tm->tsetlin_machines[i], X[l]); if (max_class_sum < class_sum) { max_class_sum = class_sum; @@ -85,7 +85,7 @@ float mc_tm_evaluate(struct MultiClassTsetlinMachine *mc_tm, int X[][FEATURES], errors += 1; } } - + return 1.0 - 1.0 * errors / number_of_examples; } @@ -100,7 +100,7 @@ void mc_tm_update(struct MultiClassTsetlinMachine *mc_tm, int Xi[], int target_c { tm_update(mc_tm->tsetlin_machines[target_class], Xi, 1, s); - // Randomly pick one of the other classes, for pairwise learning of class output + // Randomly pick one of the other classes, for pairwise learning of class output int negative_target_class = (int)CLASSES * 1.0*rand()/RAND_MAX; while (negative_target_class == target_class) { negative_target_class = (int)CLASSES * 1.0*rand()/RAND_MAX; @@ -116,7 +116,7 @@ void mc_tm_update(struct MultiClassTsetlinMachine *mc_tm, int Xi[], int target_c void mc_tm_fit(struct MultiClassTsetlinMachine *mc_tm, int X[][FEATURES], int y[], int number_of_examples, int epochs, float s) { for (int epoch = 0; epoch < epochs; epoch++) { - // Add shuffling here... + // Add shuffling here... for (int i = 0; i < number_of_examples; i++) { mc_tm_update(mc_tm, X[i], y[i], s); } diff --git a/NoisyXORDemo.c b/NoisyXORDemo.c index 81d4e49..1525518 100644 --- a/NoisyXORDemo.c +++ b/NoisyXORDemo.c @@ -59,7 +59,7 @@ void read_file(void) int main(void) -{ +{ srand(time(NULL)); read_file(); diff --git a/TsetlinMachine.c b/TsetlinMachine.c index fe17f82..2aae807 100644 --- a/TsetlinMachine.c +++ b/TsetlinMachine.c @@ -41,17 +41,7 @@ struct TsetlinMachine *CreateTsetlinMachine() /* Set up the Tsetlin Machine structure */ - for (int j = 0; j < CLAUSES; j++) { - for (int k = 0; k < FEATURES; k++) { - if (1.0 * rand()/RAND_MAX <= 0.5) { - (*tm).ta_state[j][k][0] = NUMBER_OF_STATES; - (*tm).ta_state[j][k][1] = NUMBER_OF_STATES + 1; - } else { - (*tm).ta_state[j][k][0] = NUMBER_OF_STATES + 1; - (*tm).ta_state[j][k][1] = NUMBER_OF_STATES; - } - } - } + if(tm) tm_initialize(tm); return tm; } @@ -59,14 +49,16 @@ struct TsetlinMachine *CreateTsetlinMachine() void tm_initialize(struct TsetlinMachine *tm) { - for (int j = 0; j < CLAUSES; j++) { + for (int j = 0; j < CLAUSES; j++) { + TsetlinMachineFeatures_t *clauses = tm->ta_state[j]; for (int k = 0; k < FEATURES; k++) { + int *features = clauses[k]; if (1.0 * rand()/RAND_MAX <= 0.5) { - (*tm).ta_state[j][k][0] = NUMBER_OF_STATES; - (*tm).ta_state[j][k][1] = NUMBER_OF_STATES + 1; + features[0] = NUMBER_OF_STATES; + features[1] = NUMBER_OF_STATES + 1; } else { - (*tm).ta_state[j][k][0] = NUMBER_OF_STATES + 1; - (*tm).ta_state[j][k][1] = NUMBER_OF_STATES; // Deviation, should be random + features[0] = NUMBER_OF_STATES + 1; + features[1] = NUMBER_OF_STATES; // Deviation, should be random } } } @@ -88,21 +80,23 @@ static inline void calculate_clause_output(struct TsetlinMachine *tm, int Xi[], int all_exclude; for (j = 0; j < CLAUSES; j++) { - (*tm).clause_output[j] = 1; + tm->clause_output[j] = 1; all_exclude = 1; + TsetlinMachineFeatures_t *clauses = tm->ta_state[j]; for (k = 0; k < FEATURES; k++) { - action_include = action((*tm).ta_state[j][k][0]); - action_include_negated = action((*tm).ta_state[j][k][1]); + int *features = clauses[k]; + action_include = action(features[0]); + action_include_negated = action(features[1]); all_exclude = all_exclude && !(action_include == 1 || action_include_negated == 1); if ((action_include == 1 && Xi[k] == 0) || (action_include_negated == 1 && Xi[k] == 1)) { - (*tm).clause_output[j] = 0; + tm->clause_output[j] = 0; break; } } - (*tm).clause_output[j] = (*tm).clause_output[j] && !(predict == PREDICT && all_exclude == 1); + tm->clause_output[j] = tm->clause_output[j] && !(predict == PREDICT && all_exclude == 1); } } @@ -112,9 +106,9 @@ static inline int sum_up_class_votes(struct TsetlinMachine *tm) int class_sum = 0; for (int j = 0; j < CLAUSES; j++) { int sign = 1 - 2 * (j & 1); - class_sum += (*tm).clause_output[j]*sign; + class_sum += tm->clause_output[j]*sign; } - + class_sum = (class_sum > THRESHOLD) ? THRESHOLD : class_sum; class_sum = (class_sum < -THRESHOLD) ? -THRESHOLD : class_sum; @@ -133,22 +127,24 @@ int tm_get_state(struct TsetlinMachine *tm, int clause, int feature, int automat static inline void type_i_feedback(struct TsetlinMachine *tm, int Xi[], int j, float s) { - if ((*tm).clause_output[j] == 0) { + if (tm->clause_output[j] == 0) { for (int k = 0; k < FEATURES; k++) { - (*tm).ta_state[j][k][0] -= ((*tm).ta_state[j][k][0] > 1) && (1.0*rand()/RAND_MAX <= 1.0/s); - - (*tm).ta_state[j][k][1] -= ((*tm).ta_state[j][k][1] > 1) && (1.0*rand()/RAND_MAX <= 1.0/s); + int *features = tm->ta_state[j][k]; + features[0] -= (features[0] > 1) && (1.0*rand()/RAND_MAX <= 1.0/s); + + features[1] -= (features[1] > 1) && (1.0*rand()/RAND_MAX <= 1.0/s); } - } else if ((*tm).clause_output[j] == 1) { + } else if (tm->clause_output[j] == 1) { for (int k = 0; k < FEATURES; k++) { + int *features = tm->ta_state[j][k]; if (Xi[k] == 1) { - (*tm).ta_state[j][k][0] += ((*tm).ta_state[j][k][0] < NUMBER_OF_STATES*2) && (BOOST_TRUE_POSITIVE_FEEDBACK == 1 || 1.0*rand()/RAND_MAX <= (s-1)/s); + features[0] += (features[0] < NUMBER_OF_STATES*2) && (BOOST_TRUE_POSITIVE_FEEDBACK == 1 || 1.0*rand()/RAND_MAX <= (s-1)/s); - (*tm).ta_state[j][k][1] -= ((*tm).ta_state[j][k][1] > 1) && (1.0*rand()/RAND_MAX <= 1.0/s); + features[1] -= (features[1] > 1) && (1.0*rand()/RAND_MAX <= 1.0/s); } else if (Xi[k] == 0) { - (*tm).ta_state[j][k][1] += ((*tm).ta_state[j][k][1] < NUMBER_OF_STATES*2) && (BOOST_TRUE_POSITIVE_FEEDBACK == 1 || 1.0*rand()/RAND_MAX <= (s-1)/s); - - (*tm).ta_state[j][k][0] -= ((*tm).ta_state[j][k][0] > 1) && (1.0*rand()/RAND_MAX <= 1.0/s); + features[1] += (features[1] < NUMBER_OF_STATES*2) && (BOOST_TRUE_POSITIVE_FEEDBACK == 1 || 1.0*rand()/RAND_MAX <= (s-1)/s); + + features[0] -= (features[0] > 1) && (1.0*rand()/RAND_MAX <= 1.0/s); } } } @@ -163,13 +159,14 @@ static inline void type_ii_feedback(struct TsetlinMachine *tm, int Xi[], int j) int action_include; int action_include_negated; - if ((*tm).clause_output[j] == 1) { - for (int k = 0; k < FEATURES; k++) { - action_include = action((*tm).ta_state[j][k][0]); - action_include_negated = action((*tm).ta_state[j][k][1]); + if (tm->clause_output[j] == 1) { + for (int k = 0; k < FEATURES; k++) { + int *features = tm->ta_state[j][k]; + action_include = action(features[0]); + action_include_negated = action(features[1]); - (*tm).ta_state[j][k][0] += (action_include == 0 && (*tm).ta_state[j][k][0] < NUMBER_OF_STATES*2) && (Xi[k] == 0); - (*tm).ta_state[j][k][1] += (action_include_negated == 0 && (*tm).ta_state[j][k][1] < NUMBER_OF_STATES*2) && (Xi[k] == 1); + features[0] += (action_include == 0 && features[0] < NUMBER_OF_STATES*2) && (Xi[k] == 0); + features[1] += (action_include_negated == 0 && features[1] < NUMBER_OF_STATES*2) && (Xi[k] == 1); } } } @@ -200,17 +197,17 @@ void tm_update(struct TsetlinMachine *tm, int Xi[], int target, float s) { // Calculate feedback to clauses for (int j = 0; j < CLAUSES; j++) { - (*tm).feedback_to_clauses[j] = (2*target-1)*(1 - 2 * (j & 1))*(1.0*rand()/RAND_MAX <= (1.0/(THRESHOLD*2))*(THRESHOLD + (1 - 2*target)*class_sum)); + tm->feedback_to_clauses[j] = (2*target-1)*(1 - 2 * (j & 1))*(1.0*rand()/RAND_MAX <= (1.0/(THRESHOLD*2))*(THRESHOLD + (1 - 2*target)*class_sum)); } - + /*********************************/ /*** Train Individual Automata ***/ /*********************************/ for (int j = 0; j < CLAUSES; j++) { - if ((*tm).feedback_to_clauses[j] > 0) { + if (tm->feedback_to_clauses[j] > 0) { type_i_feedback(tm, Xi, j, s); - } else if ((*tm).feedback_to_clauses[j] < 0) { + } else if (tm->feedback_to_clauses[j] < 0) { type_ii_feedback(tm, Xi, j); } } diff --git a/TsetlinMachine.h b/TsetlinMachine.h index 64bdbdd..5cdabc4 100644 --- a/TsetlinMachine.h +++ b/TsetlinMachine.h @@ -34,8 +34,13 @@ This code implements the Tsetlin Machine from paper arXiv:1804.01508 #define PREDICT 1 #define UPDATE 0 -struct TsetlinMachine { - int ta_state[CLAUSES][FEATURES][2]; +typedef int TsetlinMachineFeatures_t[2]; +typedef TsetlinMachineFeatures_t TsetlinMachineClauses_t[FEATURES]; +typedef TsetlinMachineClauses_t TsetlinMachineState_t[CLAUSES]; + +struct TsetlinMachine { + /*int ta_state[CLAUSES][FEATURES][2]*/ + TsetlinMachineState_t ta_state; int clause_output[CLAUSES];