 
 #ifdef MLIR_CRUNNERUTILS_DEFINE_FUNCTIONS
 
+#include "mlir/ExecutionEngine/SparseTensor/ArithmeticUtils.h"
 #include "mlir/ExecutionEngine/SparseTensor/COO.h"
 #include "mlir/ExecutionEngine/SparseTensor/ErrorHandling.h"
 #include "mlir/ExecutionEngine/SparseTensor/File.h"
@@ -213,6 +214,47 @@ fromMLIRSparseTensor(const SparseTensorStorage<uint64_t, uint64_t, V> *tensor,
   *pIndices = indices;
 }
 
+//===----------------------------------------------------------------------===//
+//
+// Utilities for manipulating `StridedMemRefType`.
+//
+//===----------------------------------------------------------------------===//
+
+// We shouldn't need to use `detail::safelyEQ` here since the `1` is a literal.
+#define ASSERT_NO_STRIDE(MEMREF) \
+  do { \
+    assert((MEMREF) && "Memref is nullptr"); \
+    assert(((MEMREF)->strides[0] == 1) && "Memref has non-trivial stride"); \
+  } while (false)
+
+// All our functions use `uint64_t` for ranks, but `StridedMemRefType::sizes`
+// uses `int64_t` on some platforms. So we explicitly cast this lookup to
+// ensure we get a consistent type, and we use `checkOverflowCast` rather
+// than `static_cast` just to be extremely sure that the casting can't
+// go awry. (The cast should always be safe since (1) sizes should never
+// be negative, and (2) the maximum `int64_t` is smaller than the maximum
+// `uint64_t`. But it's better to be safe than sorry.)
+#define MEMREF_GET_USIZE(MEMREF) \
+  detail::checkOverflowCast<uint64_t>((MEMREF)->sizes[0])
+
+#define ASSERT_USIZE_EQ(MEMREF, SZ) \
+  assert(detail::safelyEQ(MEMREF_GET_USIZE(MEMREF), (SZ)) && \
+         "Memref size mismatch")
+
+#define MEMREF_GET_PAYLOAD(MEMREF) ((MEMREF)->data + (MEMREF)->offset)
+
+// We make this a function rather than a macro mainly for type-safety reasons.
+// It does not modify the vector, but the parameter cannot be `const` because
+// the vector's data pointer is stored into the non-`const` memref.
+template <typename T>
+static void vectorToMemref(std::vector<T> &v, StridedMemRefType<T, 1> &ref) {
+  ref.basePtr = ref.data = v.data();
+  ref.offset = 0;
+  using SizeT = typename std::remove_reference_t<decltype(ref.sizes[0])>;
+  ref.sizes[0] = detail::checkOverflowCast<SizeT>(v.size());
+  ref.strides[0] = 1;
+}
+
 } // anonymous namespace
 
 extern "C" {
@@ -286,21 +328,21 @@ void *_mlir_ciface_newSparseTensor( // NOLINT
     StridedMemRefType<index_type, 1> *lvl2dimRef,
     StridedMemRefType<index_type, 1> *dim2lvlRef, OverheadType ptrTp,
     OverheadType indTp, PrimaryType valTp, Action action, void *ptr) {
-  assert(dimSizesRef && dimSizesRef->strides[0] == 1);
-  assert(lvlSizesRef && lvlSizesRef->strides[0] == 1);
-  assert(lvlTypesRef && lvlTypesRef->strides[0] == 1);
-  assert(lvl2dimRef && lvl2dimRef->strides[0] == 1);
-  assert(dim2lvlRef && dim2lvlRef->strides[0] == 1);
-  const uint64_t dimRank = dimSizesRef->sizes[0];
-  const uint64_t lvlRank = lvlSizesRef->sizes[0];
-  assert(dim2lvlRef->sizes[0] == (int64_t)dimRank);
-  assert(lvlTypesRef->sizes[0] == (int64_t)lvlRank &&
-         lvl2dimRef->sizes[0] == (int64_t)lvlRank);
-  const index_type *dimSizes = dimSizesRef->data + dimSizesRef->offset;
-  const index_type *lvlSizes = lvlSizesRef->data + lvlSizesRef->offset;
-  const DimLevelType *lvlTypes = lvlTypesRef->data + lvlTypesRef->offset;
-  const index_type *lvl2dim = lvl2dimRef->data + lvl2dimRef->offset;
-  const index_type *dim2lvl = dim2lvlRef->data + dim2lvlRef->offset;
+  ASSERT_NO_STRIDE(dimSizesRef);
+  ASSERT_NO_STRIDE(lvlSizesRef);
+  ASSERT_NO_STRIDE(lvlTypesRef);
+  ASSERT_NO_STRIDE(lvl2dimRef);
+  ASSERT_NO_STRIDE(dim2lvlRef);
+  const uint64_t dimRank = MEMREF_GET_USIZE(dimSizesRef);
+  const uint64_t lvlRank = MEMREF_GET_USIZE(lvlSizesRef);
+  ASSERT_USIZE_EQ(dim2lvlRef, dimRank);
+  ASSERT_USIZE_EQ(lvlTypesRef, lvlRank);
+  ASSERT_USIZE_EQ(lvl2dimRef, lvlRank);
+  const index_type *dimSizes = MEMREF_GET_PAYLOAD(dimSizesRef);
+  const index_type *lvlSizes = MEMREF_GET_PAYLOAD(lvlSizesRef);
+  const DimLevelType *lvlTypes = MEMREF_GET_PAYLOAD(lvlTypesRef);
+  const index_type *lvl2dim = MEMREF_GET_PAYLOAD(lvl2dimRef);
+  const index_type *dim2lvl = MEMREF_GET_PAYLOAD(dim2lvlRef);
 
   // Rewrite kIndex to kU64, to avoid introducing a bunch of new cases.
   // This is safe because of the static_assert above.
@@ -425,10 +467,8 @@ void *_mlir_ciface_newSparseTensor( // NOLINT
     assert(ref && tensor); \
     std::vector<V> *v; \
     static_cast<SparseTensorStorageBase *>(tensor)->getValues(&v); \
-    ref->basePtr = ref->data = v->data(); \
-    ref->offset = 0; \
-    ref->sizes[0] = v->size(); \
-    ref->strides[0] = 1; \
+    assert(v); \
+    vectorToMemref(*v, *ref); \
   }
 MLIR_SPARSETENSOR_FOREVERY_V(IMPL_SPARSEVALUES)
 #undef IMPL_SPARSEVALUES
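For orientation, this is roughly what the rewritten macro expands to for the F64 instantiation; the macro header and exact signature sit outside this hunk, so treat the sketch as an approximation rather than the literal expansion:

// Approximate expansion of IMPL_SPARSEVALUES for VNAME=F64, V=double.
void _mlir_ciface_sparseValuesF64(StridedMemRefType<double, 1> *ref,
                                  void *tensor) {
  assert(ref && tensor);
  std::vector<double> *v;
  static_cast<SparseTensorStorageBase *>(tensor)->getValues(&v);
  assert(v);
  // vectorToMemref aliases the vector's storage into the rank-1 memref
  // descriptor; no copy is made.
  vectorToMemref(*v, *ref);
}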
@@ -439,10 +479,8 @@ MLIR_SPARSETENSOR_FOREVERY_V(IMPL_SPARSEVALUES)
     assert(ref && tensor); \
     std::vector<TYPE> *v; \
     static_cast<SparseTensorStorageBase *>(tensor)->LIB(&v, d); \
-    ref->basePtr = ref->data = v->data(); \
-    ref->offset = 0; \
-    ref->sizes[0] = v->size(); \
-    ref->strides[0] = 1; \
+    assert(v); \
+    vectorToMemref(*v, *ref); \
   }
 #define IMPL_SPARSEPOINTERS(PNAME, P) \
   IMPL_GETOVERHEAD(sparsePointers##PNAME, P, getPointers)
@@ -463,16 +501,17 @@ MLIR_SPARSETENSOR_FOREVERY_O(IMPL_SPARSEINDICES)
       void *lvlCOO, StridedMemRefType<V, 0> *vref, \
       StridedMemRefType<index_type, 1> *dimIndRef, \
       StridedMemRefType<index_type, 1> *dim2lvlRef) { \
-    assert(lvlCOO && vref && dimIndRef && dim2lvlRef); \
-    assert(dimIndRef->strides[0] == 1 && dim2lvlRef->strides[0] == 1); \
-    const uint64_t rank = dimIndRef->sizes[0]; \
-    assert(dim2lvlRef->sizes[0] == (int64_t)rank); \
-    const index_type *dimInd = dimIndRef->data + dimIndRef->offset; \
-    const index_type *dim2lvl = dim2lvlRef->data + dim2lvlRef->offset; \
+    assert(lvlCOO && vref); \
+    ASSERT_NO_STRIDE(dimIndRef); \
+    ASSERT_NO_STRIDE(dim2lvlRef); \
+    const uint64_t rank = MEMREF_GET_USIZE(dimIndRef); \
+    ASSERT_USIZE_EQ(dim2lvlRef, rank); \
+    const index_type *dimInd = MEMREF_GET_PAYLOAD(dimIndRef); \
+    const index_type *dim2lvl = MEMREF_GET_PAYLOAD(dim2lvlRef); \
     std::vector<index_type> lvlInd(rank); \
     for (uint64_t d = 0; d < rank; ++d) \
       lvlInd[dim2lvl[d]] = dimInd[d]; \
-    V *value = vref->data + vref->offset; \
+    V *value = MEMREF_GET_PAYLOAD(vref); \
     static_cast<SparseTensorCOO<V> *>(lvlCOO)->add(lvlInd, *value); \
     return lvlCOO; \
   }
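The `lvlInd[dim2lvl[d]] = dimInd[d]` loop above scatters indices given in dimension order into level order. A self-contained toy illustration with made-up values:

// Toy example of the dim-to-lvl index scatter (values are made up).
#include <cassert>
#include <cstdint>
#include <vector>

int main() {
  const std::vector<uint64_t> dim2lvl = {1, 0}; // dim d maps to level dim2lvl[d]
  const std::vector<uint64_t> dimInd = {3, 7};  // indices in dimension order
  std::vector<uint64_t> lvlInd(dimInd.size());
  for (uint64_t d = 0; d < dimInd.size(); ++d)
    lvlInd[dim2lvl[d]] = dimInd[d];
  assert(lvlInd[0] == 7 && lvlInd[1] == 3); // now in level order
  return 0;
}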
@@ -483,11 +522,11 @@ MLIR_SPARSETENSOR_FOREVERY_V(IMPL_ADDELT)
   bool _mlir_ciface_getNext##VNAME(void *iter, \
       StridedMemRefType<index_type, 1> *iref, \
       StridedMemRefType<V, 0> *vref) { \
-    assert(iter && iref && vref); \
-    assert(iref->strides[0] == 1); \
-    index_type *indx = iref->data + iref->offset; \
-    V *value = vref->data + vref->offset; \
-    const uint64_t isize = iref->sizes[0]; \
+    assert(iter && vref); \
+    ASSERT_NO_STRIDE(iref); \
+    index_type *indx = MEMREF_GET_PAYLOAD(iref); \
+    V *value = MEMREF_GET_PAYLOAD(vref); \
+    const uint64_t isize = MEMREF_GET_USIZE(iref); \
     const Element<V> *elem = \
         static_cast<SparseTensorIterator<V> *>(iter)->getNext(); \
     if (elem == nullptr) \
@@ -504,11 +543,11 @@ MLIR_SPARSETENSOR_FOREVERY_V(IMPL_GETNEXT)
   void _mlir_ciface_lexInsert##VNAME(void *tensor, \
       StridedMemRefType<index_type, 1> *cref, \
       StridedMemRefType<V, 0> *vref) { \
-    assert(tensor && cref && vref); \
-    assert(cref->strides[0] == 1); \
-    index_type *cursor = cref->data + cref->offset; \
+    assert(tensor && vref); \
+    ASSERT_NO_STRIDE(cref); \
+    index_type *cursor = MEMREF_GET_PAYLOAD(cref); \
     assert(cursor); \
-    V *value = vref->data + vref->offset; \
+    V *value = MEMREF_GET_PAYLOAD(vref); \
     static_cast<SparseTensorStorageBase *>(tensor)->lexInsert(cursor, *value); \
   }
 MLIR_SPARSETENSOR_FOREVERY_V(IMPL_LEXINSERT)
@@ -519,16 +558,16 @@ MLIR_SPARSETENSOR_FOREVERY_V(IMPL_LEXINSERT)
       void *tensor, StridedMemRefType<index_type, 1> *cref, \
       StridedMemRefType<V, 1> *vref, StridedMemRefType<bool, 1> *fref, \
       StridedMemRefType<index_type, 1> *aref, index_type count) { \
-    assert(tensor && cref && vref && fref && aref); \
-    assert(cref->strides[0] == 1); \
-    assert(vref->strides[0] == 1); \
-    assert(fref->strides[0] == 1); \
-    assert(aref->strides[0] == 1); \
-    assert(vref->sizes[0] == fref->sizes[0]); \
-    index_type *cursor = cref->data + cref->offset; \
-    V *values = vref->data + vref->offset; \
-    bool *filled = fref->data + fref->offset; \
-    index_type *added = aref->data + aref->offset; \
+    assert(tensor); \
+    ASSERT_NO_STRIDE(cref); \
+    ASSERT_NO_STRIDE(vref); \
+    ASSERT_NO_STRIDE(fref); \
+    ASSERT_NO_STRIDE(aref); \
+    ASSERT_USIZE_EQ(vref, MEMREF_GET_USIZE(fref)); \
+    index_type *cursor = MEMREF_GET_PAYLOAD(cref); \
+    V *values = MEMREF_GET_PAYLOAD(vref); \
+    bool *filled = MEMREF_GET_PAYLOAD(fref); \
+    index_type *added = MEMREF_GET_PAYLOAD(aref); \
     static_cast<SparseTensorStorageBase *>(tensor)->expInsert( \
         cursor, values, filled, added, count); \
   }
@@ -537,26 +576,26 @@ MLIR_SPARSETENSOR_FOREVERY_V(IMPL_EXPINSERT)
 
 void _mlir_ciface_getSparseTensorReaderDimSizes(
     void *p, StridedMemRefType<index_type, 1> *dref) {
-  assert(p && dref);
-  assert(dref->strides[0] == 1);
-  index_type *dimSizes = dref->data + dref->offset;
+  assert(p);
+  ASSERT_NO_STRIDE(dref);
+  index_type *dimSizes = MEMREF_GET_PAYLOAD(dref);
   SparseTensorReader &file = *static_cast<SparseTensorReader *>(p);
   const index_type *sizes = file.getDimSizes();
   index_type rank = file.getRank();
-  for (uint64_t r = 0; r < rank; ++r)
+  for (index_type r = 0; r < rank; ++r)
     dimSizes[r] = sizes[r];
 }
 
 #define IMPL_GETNEXT(VNAME, V) \
   void _mlir_ciface_getSparseTensorReaderNext##VNAME( \
       void *p, StridedMemRefType<index_type, 1> *iref, \
       StridedMemRefType<V, 0> *vref) { \
-    assert(p && iref && vref); \
-    assert(iref->strides[0] == 1); \
-    index_type *indices = iref->data + iref->offset; \
+    assert(p && vref); \
+    ASSERT_NO_STRIDE(iref); \
+    index_type *indices = MEMREF_GET_PAYLOAD(iref); \
     SparseTensorReader *stfile = static_cast<SparseTensorReader *>(p); \
     index_type rank = stfile->getRank(); \
-    V *value = vref->data + vref->offset; \
+    V *value = MEMREF_GET_PAYLOAD(vref); \
     *value = stfile->readCOOElement<V>(rank, indices); \
   }
 MLIR_SPARSETENSOR_FOREVERY_V(IMPL_GETNEXT)
@@ -565,10 +604,10 @@ MLIR_SPARSETENSOR_FOREVERY_V(IMPL_GETNEXT)
 void _mlir_ciface_outSparseTensorWriterMetaData(
     void *p, index_type rank, index_type nnz,
     StridedMemRefType<index_type, 1> *dref) {
-  assert(p && dref);
-  assert(dref->strides[0] == 1);
+  assert(p);
+  ASSERT_NO_STRIDE(dref);
   assert(rank != 0);
-  index_type *dimSizes = dref->data + dref->offset;
+  index_type *dimSizes = MEMREF_GET_PAYLOAD(dref);
   SparseTensorWriter &file = *static_cast<SparseTensorWriter *>(p);
   file << rank << " " << nnz << std::endl;
   for (index_type r = 0; r < rank - 1; ++r)
@@ -580,13 +619,13 @@ void _mlir_ciface_outSparseTensorWriterMetaData(
   void _mlir_ciface_outSparseTensorWriterNext##VNAME( \
       void *p, index_type rank, StridedMemRefType<index_type, 1> *iref, \
      StridedMemRefType<V, 0> *vref) { \
-    assert(p && iref && vref); \
-    assert(iref->strides[0] == 1); \
-    index_type *indices = iref->data + iref->offset; \
+    assert(p && vref); \
+    ASSERT_NO_STRIDE(iref); \
+    index_type *indices = MEMREF_GET_PAYLOAD(iref); \
     SparseTensorWriter &file = *static_cast<SparseTensorWriter *>(p); \
-    for (uint64_t r = 0; r < rank; ++r) \
+    for (index_type r = 0; r < rank; ++r) \
       file << (indices[r] + 1) << " "; \
-    V *value = vref->data + vref->offset; \
+    V *value = MEMREF_GET_PAYLOAD(vref); \
     file << *value << std::endl; \
   }
 MLIR_SPARSETENSOR_FOREVERY_V(IMPL_OUTNEXT)
@@ -723,4 +762,9 @@ void delSparseTensorWriter(void *p) {
 
 } // extern "C"
 
+#undef MEMREF_GET_PAYLOAD
+#undef ASSERT_USIZE_EQ
+#undef MEMREF_GET_USIZE
+#undef ASSERT_NO_STRIDE
+
 #endif // MLIR_CRUNNERUTILS_DEFINE_FUNCTIONS