@@ -21,14 +21,13 @@ namespace math {
21
21
template <typename T>
22
22
void MatrixBitCodeFunctor<T>::Add(framework::Tensor* tmat,
23
23
const framework::Tensor& vec) {
24
- SimpleCodeTable code_table (num_classes_);
25
24
size_t batch_size = tmat->dims ()[0 ];
26
25
size_t width = tmat->dims ()[1 ];
27
26
for (size_t i = 0 ; i < batch_size; ++i) {
28
- auto code = code_table ( static_cast < size_t >(ids_[i]) );
29
- int code_length = code. get_length ();
27
+ auto code = code_table-> get_code (i );
28
+ int code_length = code-> get_length ();
30
29
for (int j = 0 ; j < code_length; ++j) {
31
- size_t index = code. calc_index (j);
30
+ size_t index = code-> calc_index (j);
32
31
tmat->data <T>()[i * width + j] += vec.data <T>()[index];
33
32
}
34
33
}
@@ -37,14 +36,13 @@ void MatrixBitCodeFunctor<T>::Add(framework::Tensor* tmat,
37
36
template <typename T>
38
37
void MatrixBitCodeFunctor<T>::AddGrad(const framework::Tensor& tmat,
39
38
framework::Tensor* vec) {
40
- SimpleCodeTable code_table (num_classes_);
41
39
size_t batch_size = tmat.dims ()[0 ];
42
40
size_t width = tmat.dims ()[1 ];
43
41
for (size_t i = 0 ; i < batch_size; ++i) {
44
- auto code = code_table ( static_cast < size_t >(ids_[i]) );
45
- int code_length = code. get_length ();
42
+ auto code = code_table-> get_code (i );
43
+ int code_length = code-> get_length ();
46
44
for (int j = 0 ; j < code_length; ++j) {
47
- size_t index = code. calc_index (j);
45
+ size_t index = code-> calc_index (j);
48
46
vec->data <T>()[index] += tmat.data <T>()[i * width + j];
49
47
}
50
48
}
@@ -53,15 +51,14 @@ void MatrixBitCodeFunctor<T>::AddGrad(const framework::Tensor& tmat,
53
51
template <typename T>
54
52
void MatrixBitCodeFunctor<T>::Sum(const framework::Tensor& tmat,
55
53
framework::Tensor* sum, T scale_sum) {
56
- SimpleCodeTable code_table (num_classes_);
57
54
size_t num_samples = tmat.dims ()[0 ];
58
55
size_t o_width = tmat.dims ()[1 ];
59
56
for (size_t i = 0 ; i < num_samples; ++i) {
60
57
T sm = static_cast <T>(0.0 );
61
- auto code = code_table ( static_cast < size_t >(ids_[i]) );
62
- int code_length = code. get_length ();
58
+ auto code = code_table-> get_code (i );
59
+ int code_length = code-> get_length ();
63
60
for (int j = 0 ; j < code_length; ++j) {
64
- if (code. calc_bit (j)) {
61
+ if (code-> calc_bit (j)) {
65
62
// calc_bit starts from right most bit, while data in tmat[i] is in the
66
63
// reverse order.
67
64
sm += tmat.data <T>()[i * o_width + j];
@@ -75,7 +72,6 @@ template <typename T>
75
72
void MatrixBitCodeFunctor<T>::Mul(framework::Tensor* tmat,
76
73
const framework::Tensor& weight,
77
74
const framework::Tensor& input) {
78
- SimpleCodeTable code_table (num_classes_);
79
75
size_t num_samples = tmat->dims ()[0 ];
80
76
size_t tmat_width = tmat->dims ()[1 ];
81
77
size_t input_width = input.dims ()[1 ];
@@ -84,10 +80,10 @@ void MatrixBitCodeFunctor<T>::Mul(framework::Tensor* tmat,
84
80
auto weight_value = weight.data <T>();
85
81
auto input_value = input.data <T>();
86
82
for (size_t i = 0 ; i < num_samples; ++i) {
87
- auto code = code_table ( static_cast < size_t >(ids_[i]) );
88
- int code_length = code. get_length ();
83
+ auto code = code_table-> get_code (i );
84
+ int code_length = code-> get_length ();
89
85
for (int j = 0 ; j < code_length; ++j) {
90
- size_t index = code. calc_index (j);
86
+ size_t index = code-> calc_index (j);
91
87
T sum = static_cast <T>(0.0 );
92
88
for (size_t k = 0 ; k < input_width; ++k) {
93
89
sum += weight_value[weight_width * index + k] *
@@ -102,7 +98,6 @@ template <typename T>
102
98
void MatrixBitCodeFunctor<T>::MulGradWeight(const framework::Tensor& tmat,
103
99
framework::Tensor* weight,
104
100
const framework::Tensor& input) {
105
- SimpleCodeTable code_table (num_classes_);
106
101
size_t num_samples = tmat.dims ()[0 ];
107
102
size_t input_width = input.dims ()[1 ];
108
103
size_t tmat_width = tmat.dims ()[1 ];
@@ -111,10 +106,10 @@ void MatrixBitCodeFunctor<T>::MulGradWeight(const framework::Tensor& tmat,
111
106
auto weight_value = weight->data <T>();
112
107
auto input_value = input.data <T>();
113
108
for (size_t i = 0 ; i < num_samples; ++i) {
114
- auto code = code_table ( static_cast < size_t >(ids_[i]) );
115
- int code_length = code. get_length ();
109
+ auto code = code_table-> get_code (i );
110
+ int code_length = code-> get_length ();
116
111
for (int j = 0 ; j < code_length; ++j) {
117
- size_t index = code. calc_index (j);
112
+ size_t index = code-> calc_index (j);
118
113
119
114
for (size_t k = 0 ; k < input_width; ++k) {
120
115
weight_value[weight_width * index + k] +=
@@ -128,7 +123,6 @@ template <typename T>
128
123
void MatrixBitCodeFunctor<T>::MulGradError(const framework::Tensor& tmat,
129
124
const framework::Tensor& weight,
130
125
framework::Tensor* input) {
131
- SimpleCodeTable code_table (num_classes_);
132
126
size_t num_samples = tmat.dims ()[0 ];
133
127
size_t tmat_width = tmat.dims ()[1 ];
134
128
size_t input_width = input->dims ()[1 ];
@@ -138,10 +132,10 @@ void MatrixBitCodeFunctor<T>::MulGradError(const framework::Tensor& tmat,
138
132
auto input_value = input->data <T>();
139
133
140
134
for (size_t i = 0 ; i < num_samples; ++i) {
141
- auto code = code_table ( static_cast < size_t >(ids_[i]) );
142
- int code_length = code. get_length ();
135
+ auto code = code_table-> get_code (i );
136
+ int code_length = code-> get_length ();
143
137
for (int j = 0 ; j < code_length; ++j) {
144
- size_t index = code. calc_index (j);
138
+ size_t index = code-> calc_index (j);
145
139
146
140
for (size_t k = 0 ; k < input_width; ++k) {
147
141
input_value[input_width * i + k] +=
@@ -154,14 +148,13 @@ void MatrixBitCodeFunctor<T>::MulGradError(const framework::Tensor& tmat,
154
148
155
149
template <typename T>
156
150
void MatrixBitCodeFunctor<T>::Sub(framework::Tensor* tmat) {
157
- SimpleCodeTable code_table (num_classes_);
158
151
size_t num_samples = tmat->dims ()[0 ];
159
152
size_t o_width = tmat->dims ()[1 ];
160
153
for (size_t i = 0 ; i < num_samples; ++i) {
161
- auto code = code_table ( static_cast < size_t >(ids_[i]) );
162
- int code_length = code. get_length ();
154
+ auto code = code_table-> get_code (i );
155
+ int code_length = code-> get_length ();
163
156
for (int j = 0 ; j < code_length; ++j) {
164
- if (code. calc_bit (j)) {
157
+ if (code-> calc_bit (j)) {
165
158
tmat->data <T>()[i * o_width + j] -= 1 ;
166
159
}
167
160
}
0 commit comments