@@ -49,78 +49,15 @@ namespace gc {
4949
5050namespace {
5151
52- // TODO: move these to utils
53- template <typename T>
54- static Value createTypedVector (PatternRewriter &rewriter, Location loc,
55- ArrayRef<T> values, Type elementType) {
56- mlir::VectorType vectorType =
57- mlir::VectorType::get ({static_cast <int64_t >(values.size ())}, elementType);
58- mlir::DenseElementsAttr denseAttr =
59- mlir::DenseElementsAttr::get (vectorType, values);
60- auto vector =
61- rewriter.create <mlir::arith::ConstantOp>(loc, vectorType, denseAttr)
62- .getResult ();
63- return vector;
64- }
65-
66- static Value createIndexVector (PatternRewriter &rewriter, Location loc,
67- ArrayRef<int64_t > values) {
68- return createTypedVector (rewriter, loc, values, rewriter.getIndexType ());
69- }
70-
71- static Value createIndexConstant (PatternRewriter &rewriter, Location loc,
72- int64_t value) {
73- return rewriter.create <arith::ConstantIndexOp>(loc, value);
74- }
75-
76- static Value flattenMemref (PatternRewriter &rewriter, Location loc,
77- Value srcMemref) {
78- auto srcType = cast<MemRefType>(srcMemref.getType ());
79-
80- assert (srcType && " Expected a memref type" );
81- assert (srcType.getRank () == 2 && " Expected a 2D memref" );
82-
83- int64_t flatSize = srcType.getShape ()[0 ] * srcType.getShape ()[1 ];
84-
85- Value offset = rewriter.create <arith::ConstantIndexOp>(loc, 0 );
86- Value size = rewriter.create <arith::ConstantIndexOp>(loc, flatSize);
87- Value stride = rewriter.create <arith::ConstantIndexOp>(loc, 1 );
88-
89- // Use memref.reinterpret_cast to flatten the memref
90- auto flatMemRefType = MemRefType::get ({flatSize}, srcType.getElementType (),
91- nullptr , srcType.getMemorySpace ());
92- auto flatMemref =
93- rewriter
94- .create <memref::ReinterpretCastOp>(loc, flatMemRefType, srcMemref,
95- offset, size, stride)
96- .getResult ();
97- return flatMemref;
98- }
99-
100- static bool hasSharedMemSpace (mlir::Value memref) {
101- auto type = mlir::dyn_cast<mlir::MemRefType>(memref.getType ());
102- if (!type)
103- return false ;
104-
105- auto memSpace = type.getMemorySpace ();
106- if (!memSpace)
107- return false ;
108-
109- if (auto gpuAttr = mlir::dyn_cast<mlir::gpu::AddressSpaceAttr>(memSpace))
110- return gpuAttr.getValue () == mlir::gpu::AddressSpace::Private;
111-
112- if (auto intAttr = mlir::dyn_cast<mlir::IntegerAttr>(memSpace))
113- return intAttr.getValue () ==
114- static_cast <int64_t >(mlir::gpu::AddressSpace::Private);
115-
116- return false ;
117- }
118-
11952static Value createFullMask (PatternRewriter &rewriter, Location loc,
12053 int64_t size) {
121- auto maskVal = createIndexConstant ( rewriter, loc, 32 );
54+ auto maskVal = rewriter. create <arith::ConstantIndexOp>( loc, 32 );
12255 mlir::VectorType maskVectorType =
12356 mlir::VectorType::get ({size}, rewriter.getI1Type ());
57+ // HACK: creating mask vector through this strange op instead of
58+ // simple 'arith.constant dense<true>' to avoid the mask being
59+ // moved out of the GPU kernel (it causes strange behaviour
60+ // when a bit-mask is passed as a kernel parameter).
12461 auto res = rewriter.create <vector::CreateMaskOp>(
12562 loc, maskVectorType, SmallVector<Value>({maskVal}));
12663 return res.getResult ();
@@ -826,8 +763,9 @@ static SmallVector<Value> createScatterDescriptorTiles(
826763 }
827764
828765 int64_t skipPerLoad = memrefStrides[0 ] * rowsPerLoad;
829- auto offsetPerLoad =
830- createIndexVector (rewriter, loc, SmallVector<int64_t >(32 , skipPerLoad));
766+ auto offsetPerLoad = utils::createTypedVector<int64_t >(
767+ rewriter, loc, SmallVector<int64_t >(32 , skipPerLoad),
768+ rewriter.getIndexType ());
831769
832770 auto offsetVecType = VectorType::get ({maxLoadSize}, rewriter.getIndexType ());
833771 auto descType = getTensorDescType (
@@ -843,7 +781,8 @@ static SmallVector<Value> createScatterDescriptorTiles(
843781
844782 SmallVector<Value> tiles;
845783 for (int i = 0 ; i < numColTiles; i++) {
846- auto offsetsShift = createIndexVector (rewriter, loc, offsetShiftValues[i]);
784+ auto offsetsShift = utils::createTypedVector<int64_t >(
785+ rewriter, loc, offsetShiftValues[i], rewriter.getIndexType ());
847786 auto offsets0 =
848787 rewriter.create <arith::AddIOp>(loc, blockOffsetV, offsetsShift);
849788
@@ -924,7 +863,7 @@ static SmallVector<Value> createSLMDescTiles(PatternRewriter &rewriter,
924863 }
925864
926865 // Scatter descriptors only work with 1D memrefs
927- src = flattenMemref (rewriter, loc, src);
866+ src = utils:: flattenMemref (rewriter, loc, src);
928867
929868 return createScatterDescriptorTiles (
930869 rewriter, loc, /* flatMemref=*/ src, /* loadShape2D=*/ loadShape,
@@ -938,7 +877,7 @@ static SmallVector<Value> createDescriptorTiles(
938877 std::optional<ArrayRef<int64_t >> loadOffsets = std::nullopt ,
939878 int arrayLength = 1 , bool transpose = false ) {
940879
941- if (hasSharedMemSpace (src)) {
880+ if (utils:: hasSharedMemSpace (src)) {
942881 assert (!transpose && " Transpose is not supported for shared memory" );
943882 assert (arrayLength == 1 &&
944883 " Array descriptors are not supported for shared memory" );
@@ -1092,8 +1031,8 @@ loadScatterDescTiles(PatternRewriter &rewriter, Location loc,
10921031 // Accumulator vector for the current tile (its number of elements equals to
10931032 // tileShape) HACK: we first create a flat vector of zeros and then cast it
10941033 // to the 2D shape. Otherwise 'imex::ConvertGPUXToSPIRVPass' fails.
1095- auto accumVector =
1096- createTypedVector<Attribute>( rewriter, loc, accumValues, elementType);
1034+ auto accumVector = utils::createTypedVector<Attribute>(
1035+ rewriter, loc, accumValues, elementType);
10971036 accumVector =
10981037 rewriter.create <vector::ShapeCastOp>(loc, accumVectorType, accumVector);
10991038
@@ -1961,7 +1900,7 @@ LogicalResult createMemoryFillKernel(linalg::LinalgOp linalgOp,
19611900 }
19621901
19631902 // Extract SIMD sized sub-tiles
1964- int64_t maxSizeSIMD = hasSharedMemSpace (output) ? 32 : 256 ;
1903+ int64_t maxSizeSIMD = utils:: hasSharedMemSpace (output) ? 32 : 256 ;
19651904 int64_t subTileCols = std::min (outputShape[1 ], maxSizeSIMD);
19661905 int64_t subTileRows =
19671906 std::min (outputShape[0 ], std::max (maxSizeSIMD / subTileCols, 1L ));
0 commit comments