Commit a3e4888

[mlir][sparse] Macros to clean up StridedMemRefType in the SparseTensorRuntime
In particular, this silences warnings from [-Wsign-compare]. This is a revised version of D137735, which was reverted due to a sign-comparison warning on LLVM's Windows buildbot (which was not present on MLIR's Windows buildbot).

Differences vs the previous differential:

* `vectorToMemref` now uses `detail::checkOverflowCast` to silence the warning that caused the previous differential to get reverted.
* `MEMREF_GET_USIZE` now uses `detail::checkOverflowCast` rather than `static_cast`.
* `ASSERT_USIZE_EQ` added to abbreviate another common idiom, and to ensure that we use `detail::safelyEQ` everywhere (to silence a few other warnings).
* A couple of for-loops now use `index_type` for the induction variable, since their upper bound uses that typedef too (namely `_mlir_ciface_getSparseTensorReaderDimSizes` and `_mlir_ciface_outSparseTensorWriterNext`).

Depends on D138149

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D137998
1 parent 84ef723 commit a3e4888
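
Note: `detail::checkOverflowCast` and `detail::safelyEQ` live in the `ArithmeticUtils.h` header added by D138149 and are not shown in this diff. The following is a minimal, hypothetical sketch of the idiom those helpers implement, written for illustration only; names, signatures, and bodies are assumptions, not the actual library code.

// Hypothetical sketch only -- the real helpers are in
// mlir/ExecutionEngine/SparseTensor/ArithmeticUtils.h (D138149).
#include <cassert>
#include <cstdint>
#include <type_traits>

namespace detail {

// Compare two integers for equality without tripping -Wsign-compare:
// a negative signed value can never equal an unsigned one, so check the
// sign first and then compare with matching signedness.
template <typename T, typename U>
constexpr bool safelyEQ(T t, U u) {
  if constexpr (std::is_signed_v<T> == std::is_signed_v<U>)
    return t == u;
  else if constexpr (std::is_signed_v<T>)
    return t >= 0 && static_cast<std::make_unsigned_t<T>>(t) == u;
  else
    return u >= 0 && t == static_cast<std::make_unsigned_t<U>>(u);
}

// Like static_cast, but asserts that the value is representable in the
// target type, so narrowing or sign-flipping conversions cannot go wrong
// silently in debug builds.
template <typename To, typename From>
inline To checkOverflowCast(From x) {
  const To y = static_cast<To>(x);
  assert(safelyEQ(x, y) && "cast changes the value");
  return y;
}

} // namespace detail

Under this assumption, `MEMREF_GET_USIZE(memref)` turns the signed `sizes[0]` field into a `uint64_t` rank with a debug-mode check instead of a silent (and warning-prone) conversion.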

mlir/lib/ExecutionEngine/SparseTensorRuntime.cpp

Lines changed: 109 additions & 65 deletions
@@ -50,6 +50,7 @@
 
 #ifdef MLIR_CRUNNERUTILS_DEFINE_FUNCTIONS
 
+#include "mlir/ExecutionEngine/SparseTensor/ArithmeticUtils.h"
 #include "mlir/ExecutionEngine/SparseTensor/COO.h"
 #include "mlir/ExecutionEngine/SparseTensor/ErrorHandling.h"
 #include "mlir/ExecutionEngine/SparseTensor/File.h"
@@ -213,6 +214,47 @@ fromMLIRSparseTensor(const SparseTensorStorage<uint64_t, uint64_t, V> *tensor,
   *pIndices = indices;
 }
 
+//===----------------------------------------------------------------------===//
+//
+// Utilities for manipulating `StridedMemRefType`.
+//
+//===----------------------------------------------------------------------===//
+
+// We shouldn't need to use `detail::safelyEQ` here since the `1` is a literal.
+#define ASSERT_NO_STRIDE(MEMREF) \
+  do { \
+    assert((MEMREF) && "Memref is nullptr"); \
+    assert(((MEMREF)->strides[0] == 1) && "Memref has non-trivial stride"); \
+  } while (false)
+
+// All our functions use `uint64_t` for ranks, but `StridedMemRefType::sizes`
+// uses `int64_t` on some platforms. So we explicitly cast this lookup to
+// ensure we get a consistent type, and we use `checkOverflowCast` rather
+// than `static_cast` just to be extremely sure that the casting can't
+// go awry. (The cast should aways be safe since (1) sizes should never
+// be negative, and (2) the maximum `int64_t` is smaller than the maximum
+// `uint64_t`. But it's better to be safe than sorry.)
+#define MEMREF_GET_USIZE(MEMREF) \
+  detail::checkOverflowCast<uint64_t>((MEMREF)->sizes[0])
+
+#define ASSERT_USIZE_EQ(MEMREF, SZ) \
+  assert(detail::safelyEQ(MEMREF_GET_USIZE(MEMREF), (SZ)) && \
+         "Memref size mismatch")
+
+#define MEMREF_GET_PAYLOAD(MEMREF) ((MEMREF)->data + (MEMREF)->offset)
+
+// We make this a function rather than a macro mainly for type safety
+// reasons. This function does not modify the vector, but it cannot
+// be marked `const` because it is stored into the non-`const` memref.
+template <typename T>
+static void vectorToMemref(std::vector<T> &v, StridedMemRefType<T, 1> &ref) {
+  ref.basePtr = ref.data = v.data();
+  ref.offset = 0;
+  using SizeT = typename std::remove_reference_t<decltype(ref.sizes[0])>;
+  ref.sizes[0] = detail::checkOverflowCast<SizeT>(v.size());
+  ref.strides[0] = 1;
+}
+
 } // anonymous namespace
 
 extern "C" {
@@ -286,21 +328,21 @@ void *_mlir_ciface_newSparseTensor( // NOLINT
     StridedMemRefType<index_type, 1> *lvl2dimRef,
     StridedMemRefType<index_type, 1> *dim2lvlRef, OverheadType ptrTp,
     OverheadType indTp, PrimaryType valTp, Action action, void *ptr) {
-  assert(dimSizesRef && dimSizesRef->strides[0] == 1);
-  assert(lvlSizesRef && lvlSizesRef->strides[0] == 1);
-  assert(lvlTypesRef && lvlTypesRef->strides[0] == 1);
-  assert(lvl2dimRef && lvl2dimRef->strides[0] == 1);
-  assert(dim2lvlRef && dim2lvlRef->strides[0] == 1);
-  const uint64_t dimRank = dimSizesRef->sizes[0];
-  const uint64_t lvlRank = lvlSizesRef->sizes[0];
-  assert(dim2lvlRef->sizes[0] == (int64_t)dimRank);
-  assert(lvlTypesRef->sizes[0] == (int64_t)lvlRank &&
-         lvl2dimRef->sizes[0] == (int64_t)lvlRank);
-  const index_type *dimSizes = dimSizesRef->data + dimSizesRef->offset;
-  const index_type *lvlSizes = lvlSizesRef->data + lvlSizesRef->offset;
-  const DimLevelType *lvlTypes = lvlTypesRef->data + lvlTypesRef->offset;
-  const index_type *lvl2dim = lvl2dimRef->data + lvl2dimRef->offset;
-  const index_type *dim2lvl = dim2lvlRef->data + dim2lvlRef->offset;
+  ASSERT_NO_STRIDE(dimSizesRef);
+  ASSERT_NO_STRIDE(lvlSizesRef);
+  ASSERT_NO_STRIDE(lvlTypesRef);
+  ASSERT_NO_STRIDE(lvl2dimRef);
+  ASSERT_NO_STRIDE(dim2lvlRef);
+  const uint64_t dimRank = MEMREF_GET_USIZE(dimSizesRef);
+  const uint64_t lvlRank = MEMREF_GET_USIZE(lvlSizesRef);
+  ASSERT_USIZE_EQ(dim2lvlRef, dimRank);
+  ASSERT_USIZE_EQ(lvlTypesRef, lvlRank);
+  ASSERT_USIZE_EQ(lvl2dimRef, lvlRank);
+  const index_type *dimSizes = MEMREF_GET_PAYLOAD(dimSizesRef);
+  const index_type *lvlSizes = MEMREF_GET_PAYLOAD(lvlSizesRef);
+  const DimLevelType *lvlTypes = MEMREF_GET_PAYLOAD(lvlTypesRef);
+  const index_type *lvl2dim = MEMREF_GET_PAYLOAD(lvl2dimRef);
+  const index_type *dim2lvl = MEMREF_GET_PAYLOAD(dim2lvlRef);
 
   // Rewrite kIndex to kU64, to avoid introducing a bunch of new cases.
   // This is safe because of the static_assert above.
@@ -425,10 +467,8 @@ void *_mlir_ciface_newSparseTensor( // NOLINT
     assert(ref &&tensor); \
     std::vector<V> *v; \
     static_cast<SparseTensorStorageBase *>(tensor)->getValues(&v); \
-    ref->basePtr = ref->data = v->data(); \
-    ref->offset = 0; \
-    ref->sizes[0] = v->size(); \
-    ref->strides[0] = 1; \
+    assert(v); \
+    vectorToMemref(*v, *ref); \
   }
 MLIR_SPARSETENSOR_FOREVERY_V(IMPL_SPARSEVALUES)
 #undef IMPL_SPARSEVALUES
@@ -439,10 +479,8 @@ MLIR_SPARSETENSOR_FOREVERY_V(IMPL_SPARSEVALUES)
     assert(ref &&tensor); \
     std::vector<TYPE> *v; \
     static_cast<SparseTensorStorageBase *>(tensor)->LIB(&v, d); \
-    ref->basePtr = ref->data = v->data(); \
-    ref->offset = 0; \
-    ref->sizes[0] = v->size(); \
-    ref->strides[0] = 1; \
+    assert(v); \
+    vectorToMemref(*v, *ref); \
   }
 #define IMPL_SPARSEPOINTERS(PNAME, P) \
   IMPL_GETOVERHEAD(sparsePointers##PNAME, P, getPointers)
@@ -463,16 +501,17 @@ MLIR_SPARSETENSOR_FOREVERY_O(IMPL_SPARSEINDICES)
       void *lvlCOO, StridedMemRefType<V, 0> *vref, \
       StridedMemRefType<index_type, 1> *dimIndRef, \
       StridedMemRefType<index_type, 1> *dim2lvlRef) { \
-    assert(lvlCOO &&vref &&dimIndRef &&dim2lvlRef); \
-    assert(dimIndRef->strides[0] == 1 && dim2lvlRef->strides[0] == 1); \
-    const uint64_t rank = dimIndRef->sizes[0]; \
-    assert(dim2lvlRef->sizes[0] == (int64_t)rank); \
-    const index_type *dimInd = dimIndRef->data + dimIndRef->offset; \
-    const index_type *dim2lvl = dim2lvlRef->data + dim2lvlRef->offset; \
+    assert(lvlCOO &&vref); \
+    ASSERT_NO_STRIDE(dimIndRef); \
+    ASSERT_NO_STRIDE(dim2lvlRef); \
+    const uint64_t rank = MEMREF_GET_USIZE(dimIndRef); \
+    ASSERT_USIZE_EQ(dim2lvlRef, rank); \
+    const index_type *dimInd = MEMREF_GET_PAYLOAD(dimIndRef); \
+    const index_type *dim2lvl = MEMREF_GET_PAYLOAD(dim2lvlRef); \
     std::vector<index_type> lvlInd(rank); \
     for (uint64_t d = 0; d < rank; ++d) \
       lvlInd[dim2lvl[d]] = dimInd[d]; \
-    V *value = vref->data + vref->offset; \
+    V *value = MEMREF_GET_PAYLOAD(vref); \
     static_cast<SparseTensorCOO<V> *>(lvlCOO)->add(lvlInd, *value); \
     return lvlCOO; \
   }
@@ -483,11 +522,11 @@ MLIR_SPARSETENSOR_FOREVERY_V(IMPL_ADDELT)
   bool _mlir_ciface_getNext##VNAME(void *iter, \
                                    StridedMemRefType<index_type, 1> *iref, \
                                    StridedMemRefType<V, 0> *vref) { \
-    assert(iter &&iref &&vref); \
-    assert(iref->strides[0] == 1); \
-    index_type *indx = iref->data + iref->offset; \
-    V *value = vref->data + vref->offset; \
-    const uint64_t isize = iref->sizes[0]; \
+    assert(iter &&vref); \
+    ASSERT_NO_STRIDE(iref); \
+    index_type *indx = MEMREF_GET_PAYLOAD(iref); \
+    V *value = MEMREF_GET_PAYLOAD(vref); \
+    const uint64_t isize = MEMREF_GET_USIZE(iref); \
     const Element<V> *elem = \
         static_cast<SparseTensorIterator<V> *>(iter)->getNext(); \
     if (elem == nullptr) \
@@ -504,11 +543,11 @@ MLIR_SPARSETENSOR_FOREVERY_V(IMPL_GETNEXT)
   void _mlir_ciface_lexInsert##VNAME(void *tensor, \
                                      StridedMemRefType<index_type, 1> *cref, \
                                      StridedMemRefType<V, 0> *vref) { \
-    assert(tensor &&cref &&vref); \
-    assert(cref->strides[0] == 1); \
-    index_type *cursor = cref->data + cref->offset; \
+    assert(tensor &&vref); \
+    ASSERT_NO_STRIDE(cref); \
+    index_type *cursor = MEMREF_GET_PAYLOAD(cref); \
     assert(cursor); \
-    V *value = vref->data + vref->offset; \
+    V *value = MEMREF_GET_PAYLOAD(vref); \
     static_cast<SparseTensorStorageBase *>(tensor)->lexInsert(cursor, *value); \
   }
 MLIR_SPARSETENSOR_FOREVERY_V(IMPL_LEXINSERT)
@@ -519,16 +558,16 @@ MLIR_SPARSETENSOR_FOREVERY_V(IMPL_LEXINSERT)
       void *tensor, StridedMemRefType<index_type, 1> *cref, \
       StridedMemRefType<V, 1> *vref, StridedMemRefType<bool, 1> *fref, \
       StridedMemRefType<index_type, 1> *aref, index_type count) { \
-    assert(tensor &&cref &&vref &&fref &&aref); \
-    assert(cref->strides[0] == 1); \
-    assert(vref->strides[0] == 1); \
-    assert(fref->strides[0] == 1); \
-    assert(aref->strides[0] == 1); \
-    assert(vref->sizes[0] == fref->sizes[0]); \
-    index_type *cursor = cref->data + cref->offset; \
-    V *values = vref->data + vref->offset; \
-    bool *filled = fref->data + fref->offset; \
-    index_type *added = aref->data + aref->offset; \
+    assert(tensor); \
+    ASSERT_NO_STRIDE(cref); \
+    ASSERT_NO_STRIDE(vref); \
+    ASSERT_NO_STRIDE(fref); \
+    ASSERT_NO_STRIDE(aref); \
+    ASSERT_USIZE_EQ(vref, MEMREF_GET_USIZE(fref)); \
+    index_type *cursor = MEMREF_GET_PAYLOAD(cref); \
+    V *values = MEMREF_GET_PAYLOAD(vref); \
+    bool *filled = MEMREF_GET_PAYLOAD(fref); \
+    index_type *added = MEMREF_GET_PAYLOAD(aref); \
     static_cast<SparseTensorStorageBase *>(tensor)->expInsert( \
         cursor, values, filled, added, count); \
   }
@@ -537,26 +576,26 @@ MLIR_SPARSETENSOR_FOREVERY_V(IMPL_EXPINSERT)
 
 void _mlir_ciface_getSparseTensorReaderDimSizes(
     void *p, StridedMemRefType<index_type, 1> *dref) {
-  assert(p && dref);
-  assert(dref->strides[0] == 1);
-  index_type *dimSizes = dref->data + dref->offset;
+  assert(p);
+  ASSERT_NO_STRIDE(dref);
+  index_type *dimSizes = MEMREF_GET_PAYLOAD(dref);
   SparseTensorReader &file = *static_cast<SparseTensorReader *>(p);
   const index_type *sizes = file.getDimSizes();
   index_type rank = file.getRank();
-  for (uint64_t r = 0; r < rank; ++r)
+  for (index_type r = 0; r < rank; ++r)
     dimSizes[r] = sizes[r];
 }
 
 #define IMPL_GETNEXT(VNAME, V) \
   void _mlir_ciface_getSparseTensorReaderNext##VNAME( \
       void *p, StridedMemRefType<index_type, 1> *iref, \
       StridedMemRefType<V, 0> *vref) { \
-    assert(p &&iref &&vref); \
-    assert(iref->strides[0] == 1); \
-    index_type *indices = iref->data + iref->offset; \
+    assert(p &&vref); \
+    ASSERT_NO_STRIDE(iref); \
+    index_type *indices = MEMREF_GET_PAYLOAD(iref); \
     SparseTensorReader *stfile = static_cast<SparseTensorReader *>(p); \
     index_type rank = stfile->getRank(); \
-    V *value = vref->data + vref->offset; \
+    V *value = MEMREF_GET_PAYLOAD(vref); \
     *value = stfile->readCOOElement<V>(rank, indices); \
   }
 MLIR_SPARSETENSOR_FOREVERY_V(IMPL_GETNEXT)
@@ -565,10 +604,10 @@ MLIR_SPARSETENSOR_FOREVERY_V(IMPL_GETNEXT)
 void _mlir_ciface_outSparseTensorWriterMetaData(
     void *p, index_type rank, index_type nnz,
     StridedMemRefType<index_type, 1> *dref) {
-  assert(p && dref);
-  assert(dref->strides[0] == 1);
+  assert(p);
+  ASSERT_NO_STRIDE(dref);
   assert(rank != 0);
-  index_type *dimSizes = dref->data + dref->offset;
+  index_type *dimSizes = MEMREF_GET_PAYLOAD(dref);
   SparseTensorWriter &file = *static_cast<SparseTensorWriter *>(p);
   file << rank << " " << nnz << std::endl;
   for (index_type r = 0; r < rank - 1; ++r)
@@ -580,13 +619,13 @@ void _mlir_ciface_outSparseTensorWriterMetaData(
   void _mlir_ciface_outSparseTensorWriterNext##VNAME( \
       void *p, index_type rank, StridedMemRefType<index_type, 1> *iref, \
       StridedMemRefType<V, 0> *vref) { \
-    assert(p &&iref &&vref); \
-    assert(iref->strides[0] == 1); \
-    index_type *indices = iref->data + iref->offset; \
+    assert(p &&vref); \
+    ASSERT_NO_STRIDE(iref); \
+    index_type *indices = MEMREF_GET_PAYLOAD(iref); \
     SparseTensorWriter &file = *static_cast<SparseTensorWriter *>(p); \
-    for (uint64_t r = 0; r < rank; ++r) \
+    for (index_type r = 0; r < rank; ++r) \
       file << (indices[r] + 1) << " "; \
-    V *value = vref->data + vref->offset; \
+    V *value = MEMREF_GET_PAYLOAD(vref); \
     file << *value << std::endl; \
   }
 MLIR_SPARSETENSOR_FOREVERY_V(IMPL_OUTNEXT)
@@ -723,4 +762,9 @@ void delSparseTensorWriter(void *p) {
 
 } // extern "C"
 
+#undef MEMREF_GET_PAYLOAD
+#undef ASSERT_USIZE_EQ
+#undef MEMREF_GET_USIZE
+#undef ASSERT_NO_STRIDE
+
 #endif // MLIR_CRUNNERUTILS_DEFINE_FUNCTIONS
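
For reference, the memref descriptor these macros operate on is the `StridedMemRefType` template from the MLIR execution-engine runner utilities, whose `sizes` and `strides` members are signed (`int64_t`); that is why reading a size into a `uint64_t` rank needs the checked cast. The following is a simplified view only, omitting the members of the real definition that are irrelevant here:

// Simplified data layout of StridedMemRefType (see
// mlir/include/mlir/ExecutionEngine/CRunnerUtils.h for the full class).
#include <cstdint>

template <typename T, int N>
struct StridedMemRefType {
  T *basePtr;          // allocation base
  T *data;             // aligned data pointer
  int64_t offset;      // element offset from data
  int64_t sizes[N];    // signed, hence the checked cast to uint64_t
  int64_t strides[N];  // must be 1 for the rank-1 memrefs used here
};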
