File tree Expand file tree Collapse file tree 2 files changed +21
-25
lines changed
Expand file tree Collapse file tree 2 files changed +21
-25
lines changed Original file line number Diff line number Diff line change @@ -206,31 +206,6 @@ struct MatrixSplit {
206206 }
207207};
208208
209- template <typename splitint_t , typename accumulator_t , typename fp_t >
210- std::vector<fp_t > mergeIntToFloats (const MatrixSplit<splitint_t , fp_t > &A,
211- const size_t bitsPerSlice) {
212- std::vector<fp_t > C (A.m * A.n , 0.0 );
213-
214- for (size_t i = 0 ; i < A.m ; i++) {
215- decltype (A.memory [0 ]) tmp = 0 ;
216- for (size_t j = 0 ; j < A.n ; j++) {
217- int8_t shiftValue = computeNumFracBits<fp_t >() - bitsPerSlice;
218- for (size_t iBlock = 0 ; iBlock < A.numSplits ; iBlock++) {
219- auto slice = A.memory [i + j * A.m + iBlock * A.m * A.n ];
220- auto new_slice = shiftValue > 0 ?
221- slice << shiftValue :
222- slice >> -shiftValue;
223- tmp |= new_slice;
224- shiftValue -= bitsPerSlice;
225- }
226- C[i + j * A.m ] = std::ldexp (tmp, -(int )computeNumFracBits<fp_t >()) *
227- A.powersVector [i];
228- }
229- }
230-
231- return C;
232- }
233-
234209/* Compute exact products of slices of A and B. */
235210template <typename splitint_t , typename accumulator_t , typename fp_t >
236211void computeExactIntegerGEMM (const MatrixSplit<splitint_t , fp_t > &A,
Original file line number Diff line number Diff line change 55#include < iomanip>
66#include < vector>
77
8+ template <typename splitint_t , typename fp_t >
9+ std::vector<fp_t > convertIntSlicesToFloatMatrix (const MatrixSplit<splitint_t , fp_t > &splitA,
10+ const size_t bitsPerSlice) {
11+ std::vector<fp_t > C (splitA.m * splitA.n , 0.0 );
12+
13+ for (size_t i = 0 ; i < splitA.m ; i++) {
14+ for (size_t j = 0 ; j < splitA.n ; j++) {
15+ fp_t tmp = 0 ;
16+ for (size_t slice = 0 ; slice < splitA.numSplits ; slice++) {
17+ fp_t currentSlice = splitA.memory [i + j * splitA.m + slice * splitA.m * splitA.n ];
18+ tmp += ldexp (currentSlice, -(slice + 1 ) * bitsPerSlice);
19+ }
20+ size_t scalingIndex = splitA.dimension == normalisationDimension::byRows ? i : j;
21+ C[i + j * splitA.m ] = tmp * splitA.powersVector [scalingIndex];
22+ assert (C[i + j * splitA.m ] == ldexp (tmp, splitA.scalingExponents [scalingIndex]));
23+ }
24+ }
25+
26+ return C;
27+ }
28+
829template <typename T>
930void print_matrix (std::vector<T> A, const size_t m, const size_t n,
1031 const std::string id_string) {
You can’t perform that action at this time.
0 commit comments