2020#ifndef MLIR_EXECUTIONENGINE_SPARSETENSOR_FILE_H
2121#define MLIR_EXECUTIONENGINE_SPARSETENSOR_FILE_H
2222
23+ #include " mlir/ExecutionEngine/SparseTensor/MapRef.h"
2324#include " mlir/ExecutionEngine/SparseTensor/Storage.h"
2425
2526#include < fstream>
@@ -75,6 +76,10 @@ inline V readValue(char **linePtr, bool isPattern) {
7576
7677} // namespace detail
7778
79+ // ===----------------------------------------------------------------------===//
80+ //
81+ // Reader class.
82+ //
7883// ===----------------------------------------------------------------------===//
7984
8085// / This class abstracts over the information stored in file headers,
@@ -132,6 +137,7 @@ class SparseTensorReader final {
132137 // / Reads and parses the file's header.
133138 void readHeader ();
134139
140+ // / Returns the stored value kind.
135141 ValueKind getValueKind () const { return valueKind_; }
136142
137143 // / Checks if a header has been successfully read.
@@ -185,58 +191,37 @@ class SparseTensorReader final {
185191 // / valid after parsing the header.
186192 void assertMatchesShape (uint64_t rank, const uint64_t *shape) const ;
187193
188- // / Reads a sparse tensor element from the next line in the input file and
189- // / returns the value of the element. Stores the coordinates of the element
190- // / to the `dimCoords` array.
191- template <typename V>
192- V readElement (uint64_t dimRank, uint64_t *dimCoords) {
193- assert (dimRank == getRank () && " rank mismatch" );
194- char *linePtr = readCoords (dimCoords);
195- return detail::readValue<V>(&linePtr, isPattern ());
196- }
197-
198- // / Allocates a new COO object for `lvlSizes`, initializes it by reading
199- // / all the elements from the file and applying `dim2lvl` to their
200- // / dim-coordinates, and then closes the file. Templated on V only.
201- template <typename V>
202- SparseTensorCOO<V> *readCOO (uint64_t lvlRank, const uint64_t *lvlSizes,
203- const uint64_t *dim2lvl);
204-
205194 // / Allocates a new sparse-tensor storage object with the given encoding,
206195 // / initializes it by reading all the elements from the file, and then
207196 // / closes the file. Templated on P, I, and V.
208197 template <typename P, typename I, typename V>
209198 SparseTensorStorage<P, I, V> *
210199 readSparseTensor (uint64_t lvlRank, const uint64_t *lvlSizes,
211- const DimLevelType *lvlTypes, const uint64_t *lvl2dim,
212- const uint64_t *dim2lvl) {
213- auto *lvlCOO = readCOO<V>(lvlRank, lvlSizes, dim2lvl);
200+ const DimLevelType *lvlTypes, const uint64_t *dim2lvl,
201+ const uint64_t *lvl2dim) {
202+ const uint64_t dimRank = getRank ();
203+ MapRef map (dimRank, lvlRank, dim2lvl, lvl2dim);
204+ auto *coo = readCOO<V>(map, lvlSizes);
214205 auto *tensor = SparseTensorStorage<P, I, V>::newFromCOO (
215- getRank () , getDimSizes (), lvlRank, lvlTypes, lvl2dim, *lvlCOO );
216- delete lvlCOO ;
206+ dimRank , getDimSizes (), lvlRank, lvlTypes, lvl2dim, *coo );
207+ delete coo ;
217208 return tensor;
218209 }
219210
220211 // / Reads the COO tensor from the file, stores the coordinates and values to
221212 // / the given buffers, returns a boolean value to indicate whether the COO
222213 // / elements are sorted.
223- // / Precondition: the buffers should have enough space to hold the elements.
224214 template <typename C, typename V>
225215 bool readToBuffers (uint64_t lvlRank, const uint64_t *dim2lvl,
226- C *lvlCoordinates, V *values);
216+ const uint64_t *lvl2dim, C *lvlCoordinates, V *values);
227217
228218private:
229- // / Attempts to read a line from the file. Is private because there's
230- // / no reason for client code to call it.
219+ // / Attempts to read a line from the file.
231220 void readLine ();
232221
233222 // / Reads the next line of the input file and parses the coordinates
234223 // / into the `dimCoords` argument. Returns the position in the `line`
235- // / buffer where the element's value should be parsed from. This method
236- // / has been factored out from `readElement` to minimize code bloat
237- // / for the generated library.
238- // /
239- // / Precondition: `dimCoords` is valid for `getRank()`.
224+ // / buffer where the element's value should be parsed from.
240225 template <typename C>
241226 char *readCoords (C *dimCoords) {
242227 readLine ();
@@ -251,24 +236,20 @@ class SparseTensorReader final {
251236 return linePtr;
252237 }
253238
254- // / The internal implementation of `readCOO`. We template over
255- // / `IsPattern` in order to perform LICM without needing to duplicate the
256- // / source code.
257- //
258- // TODO: We currently take the `dim2lvl` argument as a `PermutationRef`
259- // since that's what `readCOO` creates. Once we update `readCOO` to
260- // functionalize the mapping, then this helper will just take that
261- // same function.
239+ // / Reads all the elements from the file while applying the given map.
240+ template <typename V>
241+ SparseTensorCOO<V> *readCOO (const MapRef &map, const uint64_t *lvlSizes);
242+
243+ // / The implementation of `readCOO` that is templated `IsPattern` in order
244+ // / to perform LICM without needing to duplicate the source code.
262245 template <typename V, bool IsPattern>
263- void readCOOLoop (uint64_t lvlRank, detail::PermutationRef dim2lvl,
264- SparseTensorCOO<V> *lvlCOO);
246+ void readCOOLoop (const MapRef &map, SparseTensorCOO<V> *coo);
265247
266- // / The internal implementation of `readToBuffers`. We template over
267- // / `IsPattern` in order to perform LICM without needing to duplicate the
268- // / source code.
248+ // / The internal implementation of `readToBuffers`. We template over
249+ // / `IsPattern` in order to perform LICM without needing to duplicate
250+ // / the source code.
269251 template <typename C, typename V, bool IsPattern>
270- bool readToBuffersLoop (uint64_t lvlRank, detail::PermutationRef dim2lvl,
271- C *lvlCoordinates, V *values);
252+ bool readToBuffersLoop (const MapRef &map, C *lvlCoordinates, V *values);
272253
273254 // / Reads the MME header of a general sparse matrix of type real.
274255 void readMMEHeader ();
@@ -288,96 +269,76 @@ class SparseTensorReader final {
288269 char line[kColWidth ];
289270};
290271
272+ // ===----------------------------------------------------------------------===//
273+ //
274+ // Reader class methods.
275+ //
291276// ===----------------------------------------------------------------------===//
292277
293278template <typename V>
294- SparseTensorCOO<V> *SparseTensorReader::readCOO (uint64_t lvlRank,
295- const uint64_t *lvlSizes,
296- const uint64_t *dim2lvl) {
279+ SparseTensorCOO<V> *SparseTensorReader::readCOO (const MapRef &map,
280+ const uint64_t *lvlSizes) {
297281 assert (isValid () && " Attempt to readCOO() before readHeader()" );
298- const uint64_t dimRank = getRank ();
299- assert (lvlRank == dimRank && " Rank mismatch" );
300- detail::PermutationRef d2l (dimRank, dim2lvl);
301282 // Prepare a COO object with the number of stored elems as initial capacity.
302- auto *lvlCOO = new SparseTensorCOO<V>(lvlRank, lvlSizes, getNSE ());
303- // Do some manual LICM, to avoid assertions in the for-loop.
304- const bool IsPattern = isPattern ();
305- if (IsPattern)
306- readCOOLoop<V, true >(lvlRank, d2l, lvlCOO);
283+ auto *coo = new SparseTensorCOO<V>(map.getLvlRank (), lvlSizes, getNSE ());
284+ // Enter the reading loop.
285+ if (isPattern ())
286+ readCOOLoop<V, true >(map, coo);
307287 else
308- readCOOLoop<V, false >(lvlRank, d2l, lvlCOO );
288+ readCOOLoop<V, false >(map, coo );
309289 // Close the file and return the COO.
310290 closeFile ();
311- return lvlCOO ;
291+ return coo ;
312292}
313293
314294template <typename V, bool IsPattern>
315- void SparseTensorReader::readCOOLoop (uint64_t lvlRank,
316- detail::PermutationRef dim2lvl,
317- SparseTensorCOO<V> *lvlCOO) {
318- const uint64_t dimRank = getRank ();
295+ void SparseTensorReader::readCOOLoop (const MapRef &map,
296+ SparseTensorCOO<V> *coo) {
297+ const uint64_t dimRank = map.getDimRank ();
298+ const uint64_t lvlRank = map.getLvlRank ();
299+ assert (dimRank == getRank ());
319300 std::vector<uint64_t > dimCoords (dimRank);
320301 std::vector<uint64_t > lvlCoords (lvlRank);
321- for (uint64_t nse = getNSE (), k = 0 ; k < nse; ++k) {
322- // We inline `readElement` here in order to avoid redundant
323- // assertions, since they're guaranteed by the call to `isValid()`
324- // and the construction of `dimCoords` above.
302+ for (uint64_t k = 0 , nse = getNSE (); k < nse; k++) {
325303 char *linePtr = readCoords (dimCoords.data ());
326304 const V value = detail::readValue<V, IsPattern>(&linePtr);
327- dim2lvl.pushforward (dimRank, dimCoords.data (), lvlCoords.data ());
328- // TODO: <https://github.com/llvm/llvm-project/issues/54179>
329- lvlCOO->add (lvlCoords, value);
305+ map.pushforward (dimCoords.data (), lvlCoords.data ());
306+ coo->add (lvlCoords, value);
330307 }
331308}
332309
333310template <typename C, typename V>
334311bool SparseTensorReader::readToBuffers (uint64_t lvlRank,
335312 const uint64_t *dim2lvl,
313+ const uint64_t *lvl2dim,
336314 C *lvlCoordinates, V *values) {
337315 assert (isValid () && " Attempt to readCOO() before readHeader()" );
338- // Construct a `PermutationRef` for the `pushforward` below.
339- // TODO: This specific implementation does not generalize to arbitrary
340- // mappings, but once we functionalize the `dim2lvl` argument we can
341- // simply use that function instead.
342- const uint64_t dimRank = getRank ();
343- assert (lvlRank == dimRank && " Rank mismatch" );
344- detail::PermutationRef d2l (dimRank, dim2lvl);
345- // Do some manual LICM, to avoid assertions in the for-loop.
316+ MapRef map (getRank (), lvlRank, dim2lvl, lvl2dim);
346317 bool isSorted =
347- isPattern ()
348- ? readToBuffersLoop<C, V, true >(lvlRank, d2l, lvlCoordinates, values)
349- : readToBuffersLoop<C, V, false >(lvlRank, d2l, lvlCoordinates,
350- values);
351-
352- // Close the file and return isSorted.
318+ isPattern () ? readToBuffersLoop<C, V, true >(map, lvlCoordinates, values)
319+ : readToBuffersLoop<C, V, false >(map, lvlCoordinates, values);
353320 closeFile ();
354321 return isSorted;
355322}
356323
357324template <typename C, typename V, bool IsPattern>
358- bool SparseTensorReader::readToBuffersLoop (uint64_t lvlRank ,
359- detail::PermutationRef dim2lvl,
360- C *lvlCoordinates, V *values) {
361- const uint64_t dimRank = getRank ();
325+ bool SparseTensorReader::readToBuffersLoop (const MapRef &map, C *lvlCoordinates ,
326+ V *values) {
327+ const uint64_t dimRank = map. getDimRank ();
328+ const uint64_t lvlRank = map. getLvlRank ();
362329 const uint64_t nse = getNSE ();
330+ assert (dimRank == getRank ());
363331 std::vector<C> dimCoords (dimRank);
364- // Read the first element with isSorted=false as a way to avoid accessing its
365- // previous element.
366332 bool isSorted = false ;
367333 char *linePtr;
368- // We inline `readElement` here in order to avoid redundant assertions,
369- // since they're guaranteed by the call to `isValid()` and the construction
370- // of `dimCoords` above.
371334 const auto readNextElement = [&]() {
372335 linePtr = readCoords<C>(dimCoords.data ());
373- dim2lvl .pushforward (dimRank, dimCoords.data (), lvlCoordinates);
336+ map .pushforward (dimCoords.data (), lvlCoordinates);
374337 *values = detail::readValue<V, IsPattern>(&linePtr);
375338 if (isSorted) {
376- // Note that isSorted was set to false while reading the first element,
339+ // Note that isSorted is set to false when reading the first element,
377340 // to guarantee the safeness of using prevLvlCoords.
378341 C *prevLvlCoords = lvlCoordinates - lvlRank;
379- // TODO: define a new CoordsLT which is like ElementLT but doesn't have
380- // the V parameter, and use it here.
381342 for (uint64_t l = 0 ; l < lvlRank; ++l) {
382343 if (prevLvlCoords[l] != lvlCoordinates[l]) {
383344 if (prevLvlCoords[l] > lvlCoordinates[l])
@@ -393,7 +354,6 @@ bool SparseTensorReader::readToBuffersLoop(uint64_t lvlRank,
393354 isSorted = true ;
394355 for (uint64_t n = 1 ; n < nse; ++n)
395356 readNextElement ();
396-
397357 return isSorted;
398358}
399359
0 commit comments