Skip to content

Commit 788e1a2

Browse files
committed
Update function to test integer splitting and move to utilities.hpp
1 parent e2e1e84 commit 788e1a2

File tree

2 files changed

+21
-25
lines changed

2 files changed

+21
-25
lines changed

include/gemmi.hpp

Lines changed: 0 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -206,31 +206,6 @@ struct MatrixSplit {
206206
}
207207
};
208208

209-
template <typename splitint_t, typename accumulator_t, typename fp_t>
210-
std::vector<fp_t> mergeIntToFloats(const MatrixSplit<splitint_t, fp_t> &A,
211-
const size_t bitsPerSlice) {
212-
std::vector<fp_t> C (A.m * A.n, 0.0);
213-
214-
for (size_t i = 0; i < A.m; i++) {
215-
decltype(A.memory[0]) tmp = 0;
216-
for (size_t j = 0; j < A.n; j++) {
217-
int8_t shiftValue = computeNumFracBits<fp_t>() - bitsPerSlice;
218-
for (size_t iBlock = 0; iBlock < A.numSplits; iBlock++) {
219-
auto slice = A.memory[i + j * A.m + iBlock * A.m * A.n];
220-
auto new_slice = shiftValue > 0 ?
221-
slice << shiftValue :
222-
slice >> -shiftValue;
223-
tmp |= new_slice;
224-
shiftValue -= bitsPerSlice;
225-
}
226-
C[i + j * A.m] = std::ldexp(tmp, -(int)computeNumFracBits<fp_t>()) *
227-
A.powersVector[i];
228-
}
229-
}
230-
231-
return C;
232-
}
233-
234209
/* Compute exact products of slices of A and B. */
235210
template <typename splitint_t, typename accumulator_t, typename fp_t>
236211
void computeExactIntegerGEMM(const MatrixSplit<splitint_t, fp_t> &A,

include/utilities.hpp

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,27 @@
55
#include <cassert>
#include <cmath>
#include <iomanip>
#include <vector>
77

8+
template <typename splitint_t, typename fp_t>
9+
std::vector<fp_t> convertIntSlicesToFloatMatrix(const MatrixSplit<splitint_t, fp_t> &splitA,
10+
const size_t bitsPerSlice) {
11+
std::vector<fp_t> C (splitA.m * splitA.n, 0.0);
12+
13+
for (size_t i = 0; i < splitA.m; i++) {
14+
for (size_t j = 0; j < splitA.n; j++) {
15+
fp_t tmp = 0;
16+
for (size_t slice = 0; slice < splitA.numSplits; slice++) {
17+
fp_t currentSlice = splitA.memory[i + j * splitA.m + slice * splitA.m * splitA.n];
18+
tmp += ldexp(currentSlice, -(slice + 1) * bitsPerSlice);
19+
}
20+
size_t scalingIndex = splitA.dimension == normalisationDimension::byRows ? i : j;
21+
C[i + j * splitA.m] = tmp * splitA.powersVector[scalingIndex];
22+
assert(C[i + j * splitA.m] == ldexp(tmp, splitA.scalingExponents[scalingIndex]));
23+
}
24+
}
25+
26+
return C;
27+
}
28+
829
template <typename T>
930
void print_matrix(std::vector<T> A, const size_t m, const size_t n,
1031
const std::string id_string) {

0 commit comments

Comments
 (0)