Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions MultiClassTsetlinMachine.c
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ float mc_tm_evaluate(struct MultiClassTsetlinMachine *mc_tm, int X[][FEATURES],

max_class_sum = tm_score(mc_tm->tsetlin_machines[0], X[l]);
max_class = 0;
for (int i = 1; i < CLASSES; i++) {
for (int i = 1; i < CLASSES; i++) {
int class_sum = tm_score(mc_tm->tsetlin_machines[i], X[l]);
if (max_class_sum < class_sum) {
max_class_sum = class_sum;
Expand All @@ -85,7 +85,7 @@ float mc_tm_evaluate(struct MultiClassTsetlinMachine *mc_tm, int X[][FEATURES],
errors += 1;
}
}

return 1.0 - 1.0 * errors / number_of_examples;
}

Expand All @@ -100,7 +100,7 @@ void mc_tm_update(struct MultiClassTsetlinMachine *mc_tm, int Xi[], int target_c
{
tm_update(mc_tm->tsetlin_machines[target_class], Xi, 1, s);

// Randomly pick one of the other classes, for pairwise learning of class output
// Randomly pick one of the other classes, for pairwise learning of class output
int negative_target_class = (int)CLASSES * 1.0*rand()/RAND_MAX;
while (negative_target_class == target_class) {
negative_target_class = (int)CLASSES * 1.0*rand()/RAND_MAX;
Expand All @@ -116,7 +116,7 @@ void mc_tm_update(struct MultiClassTsetlinMachine *mc_tm, int Xi[], int target_c
void mc_tm_fit(struct MultiClassTsetlinMachine *mc_tm, int X[][FEATURES], int y[], int number_of_examples, int epochs, float s)
{
for (int epoch = 0; epoch < epochs; epoch++) {
// Add shuffling here...
// Add shuffling here...
for (int i = 0; i < number_of_examples; i++) {
mc_tm_update(mc_tm, X[i], y[i], s);
}
Expand Down
2 changes: 1 addition & 1 deletion NoisyXORDemo.c
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ void read_file(void)


int main(void)
{
{
srand(time(NULL));

read_file();
Expand Down
83 changes: 40 additions & 43 deletions TsetlinMachine.c
Original file line number Diff line number Diff line change
Expand Up @@ -41,32 +41,24 @@ struct TsetlinMachine *CreateTsetlinMachine()

/* Set up the Tsetlin Machine structure */

for (int j = 0; j < CLAUSES; j++) {
for (int k = 0; k < FEATURES; k++) {
if (1.0 * rand()/RAND_MAX <= 0.5) {
(*tm).ta_state[j][k][0] = NUMBER_OF_STATES;
(*tm).ta_state[j][k][1] = NUMBER_OF_STATES + 1;
} else {
(*tm).ta_state[j][k][0] = NUMBER_OF_STATES + 1;
(*tm).ta_state[j][k][1] = NUMBER_OF_STATES;
}
}
}
if(tm) tm_initialize(tm);

return tm;
}


void tm_initialize(struct TsetlinMachine *tm)
{
for (int j = 0; j < CLAUSES; j++) {
for (int j = 0; j < CLAUSES; j++) {
TsetlinMachineFeatures_t *clauses = tm->ta_state[j];
for (int k = 0; k < FEATURES; k++) {
int *features = clauses[k];
if (1.0 * rand()/RAND_MAX <= 0.5) {
(*tm).ta_state[j][k][0] = NUMBER_OF_STATES;
(*tm).ta_state[j][k][1] = NUMBER_OF_STATES + 1;
features[0] = NUMBER_OF_STATES;
features[1] = NUMBER_OF_STATES + 1;
} else {
(*tm).ta_state[j][k][0] = NUMBER_OF_STATES + 1;
(*tm).ta_state[j][k][1] = NUMBER_OF_STATES; // Deviation, should be random
features[0] = NUMBER_OF_STATES + 1;
features[1] = NUMBER_OF_STATES; // Deviation, should be random
}
}
}
Expand All @@ -88,21 +80,23 @@ static inline void calculate_clause_output(struct TsetlinMachine *tm, int Xi[],
int all_exclude;

for (j = 0; j < CLAUSES; j++) {
(*tm).clause_output[j] = 1;
tm->clause_output[j] = 1;
all_exclude = 1;
TsetlinMachineFeatures_t *clauses = tm->ta_state[j];
for (k = 0; k < FEATURES; k++) {
action_include = action((*tm).ta_state[j][k][0]);
action_include_negated = action((*tm).ta_state[j][k][1]);
int *features = clauses[k];
action_include = action(features[0]);
action_include_negated = action(features[1]);

all_exclude = all_exclude && !(action_include == 1 || action_include_negated == 1);

if ((action_include == 1 && Xi[k] == 0) || (action_include_negated == 1 && Xi[k] == 1)) {
(*tm).clause_output[j] = 0;
tm->clause_output[j] = 0;
break;
}
}

(*tm).clause_output[j] = (*tm).clause_output[j] && !(predict == PREDICT && all_exclude == 1);
tm->clause_output[j] = tm->clause_output[j] && !(predict == PREDICT && all_exclude == 1);
}
}

Expand All @@ -112,9 +106,9 @@ static inline int sum_up_class_votes(struct TsetlinMachine *tm)
int class_sum = 0;
for (int j = 0; j < CLAUSES; j++) {
int sign = 1 - 2 * (j & 1);
class_sum += (*tm).clause_output[j]*sign;
class_sum += tm->clause_output[j]*sign;
}

class_sum = (class_sum > THRESHOLD) ? THRESHOLD : class_sum;
class_sum = (class_sum < -THRESHOLD) ? -THRESHOLD : class_sum;

Expand All @@ -133,22 +127,24 @@ int tm_get_state(struct TsetlinMachine *tm, int clause, int feature, int automat

static inline void type_i_feedback(struct TsetlinMachine *tm, int Xi[], int j, float s)
{
if ((*tm).clause_output[j] == 0) {
if (tm->clause_output[j] == 0) {
for (int k = 0; k < FEATURES; k++) {
(*tm).ta_state[j][k][0] -= ((*tm).ta_state[j][k][0] > 1) && (1.0*rand()/RAND_MAX <= 1.0/s);

(*tm).ta_state[j][k][1] -= ((*tm).ta_state[j][k][1] > 1) && (1.0*rand()/RAND_MAX <= 1.0/s);
int *features = tm->ta_state[j][k];
features[0] -= (features[0] > 1) && (1.0*rand()/RAND_MAX <= 1.0/s);

features[1] -= (features[1] > 1) && (1.0*rand()/RAND_MAX <= 1.0/s);
}
} else if ((*tm).clause_output[j] == 1) {
} else if (tm->clause_output[j] == 1) {
for (int k = 0; k < FEATURES; k++) {
int *features = tm->ta_state[j][k];
if (Xi[k] == 1) {
(*tm).ta_state[j][k][0] += ((*tm).ta_state[j][k][0] < NUMBER_OF_STATES*2) && (BOOST_TRUE_POSITIVE_FEEDBACK == 1 || 1.0*rand()/RAND_MAX <= (s-1)/s);
features[0] += (features[0] < NUMBER_OF_STATES*2) && (BOOST_TRUE_POSITIVE_FEEDBACK == 1 || 1.0*rand()/RAND_MAX <= (s-1)/s);

(*tm).ta_state[j][k][1] -= ((*tm).ta_state[j][k][1] > 1) && (1.0*rand()/RAND_MAX <= 1.0/s);
features[1] -= (features[1] > 1) && (1.0*rand()/RAND_MAX <= 1.0/s);
} else if (Xi[k] == 0) {
(*tm).ta_state[j][k][1] += ((*tm).ta_state[j][k][1] < NUMBER_OF_STATES*2) && (BOOST_TRUE_POSITIVE_FEEDBACK == 1 || 1.0*rand()/RAND_MAX <= (s-1)/s);
(*tm).ta_state[j][k][0] -= ((*tm).ta_state[j][k][0] > 1) && (1.0*rand()/RAND_MAX <= 1.0/s);
features[1] += (features[1] < NUMBER_OF_STATES*2) && (BOOST_TRUE_POSITIVE_FEEDBACK == 1 || 1.0*rand()/RAND_MAX <= (s-1)/s);

features[0] -= (features[0] > 1) && (1.0*rand()/RAND_MAX <= 1.0/s);
}
}
}
Expand All @@ -163,13 +159,14 @@ static inline void type_ii_feedback(struct TsetlinMachine *tm, int Xi[], int j)
int action_include;
int action_include_negated;

if ((*tm).clause_output[j] == 1) {
for (int k = 0; k < FEATURES; k++) {
action_include = action((*tm).ta_state[j][k][0]);
action_include_negated = action((*tm).ta_state[j][k][1]);
if (tm->clause_output[j] == 1) {
for (int k = 0; k < FEATURES; k++) {
int *features = tm->ta_state[j][k];
action_include = action(features[0]);
action_include_negated = action(features[1]);

(*tm).ta_state[j][k][0] += (action_include == 0 && (*tm).ta_state[j][k][0] < NUMBER_OF_STATES*2) && (Xi[k] == 0);
(*tm).ta_state[j][k][1] += (action_include_negated == 0 && (*tm).ta_state[j][k][1] < NUMBER_OF_STATES*2) && (Xi[k] == 1);
features[0] += (action_include == 0 && features[0] < NUMBER_OF_STATES*2) && (Xi[k] == 0);
features[1] += (action_include_negated == 0 && features[1] < NUMBER_OF_STATES*2) && (Xi[k] == 1);
}
}
}
Expand Down Expand Up @@ -200,17 +197,17 @@ void tm_update(struct TsetlinMachine *tm, int Xi[], int target, float s) {

// Calculate feedback to clauses
for (int j = 0; j < CLAUSES; j++) {
(*tm).feedback_to_clauses[j] = (2*target-1)*(1 - 2 * (j & 1))*(1.0*rand()/RAND_MAX <= (1.0/(THRESHOLD*2))*(THRESHOLD + (1 - 2*target)*class_sum));
tm->feedback_to_clauses[j] = (2*target-1)*(1 - 2 * (j & 1))*(1.0*rand()/RAND_MAX <= (1.0/(THRESHOLD*2))*(THRESHOLD + (1 - 2*target)*class_sum));
}

/*********************************/
/*** Train Individual Automata ***/
/*********************************/

for (int j = 0; j < CLAUSES; j++) {
if ((*tm).feedback_to_clauses[j] > 0) {
if (tm->feedback_to_clauses[j] > 0) {
type_i_feedback(tm, Xi, j, s);
} else if ((*tm).feedback_to_clauses[j] < 0) {
} else if (tm->feedback_to_clauses[j] < 0) {
type_ii_feedback(tm, Xi, j);
}
}
Expand Down
9 changes: 7 additions & 2 deletions TsetlinMachine.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,13 @@ This code implements the Tsetlin Machine from paper arXiv:1804.01508
#define PREDICT 1
#define UPDATE 0

struct TsetlinMachine {
int ta_state[CLAUSES][FEATURES][2];
typedef int TsetlinMachineFeatures_t[2];
typedef TsetlinMachineFeatures_t TsetlinMachineClauses_t[FEATURES];
typedef TsetlinMachineClauses_t TsetlinMachineState_t[CLAUSES];

struct TsetlinMachine {
/*int ta_state[CLAUSES][FEATURES][2]*/
TsetlinMachineState_t ta_state;

int clause_output[CLAUSES];

Expand Down