Skip to content

Commit 9728975

Browse files
authored
Add support for non memref source for init tile in wg to sg pass (#1060)
Add support for non memref source for init tile
1 parent 6d573d1 commit 9728975

File tree

2 files changed

+13
-7
lines changed

2 files changed

+13
-7
lines changed

lib/Dialect/XeTile/Transforms/WgToSg.cpp

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -208,20 +208,18 @@ class WGToSGInitTileOpPattern : public OpConversionPattern<xetile::InitTileOp> {
208208
for (size_t j = 0; j < offsets.size() - 2; ++j) {
209209
newOffsets.push_back(offsets[j]);
210210
}
211+
212+
auto sourceMemRefType = mlir::dyn_cast<mlir::MemRefType>(source.getType());
213+
211214
for (size_t i = 0; i < offsetPermutations.size(); i++) {
212215
newOffsets.push_back(offsetPermutations[i][0]);
213216
newOffsets.push_back(offsetPermutations[i][1]);
214217
Value newInitTileOp = nullptr;
215-
auto sourceMemRefType = mlir::dyn_cast<mlir::MemRefType>(source.getType());
216-
if (!sourceMemRefType) {
217-
return failure();
218-
}
219-
220-
if (sourceMemRefType.hasStaticShape()) {
218+
if (sourceMemRefType && sourceMemRefType.hasStaticShape()) {
221219
newInitTileOp = rewriter.create<xetile::InitTileOp>(
222220
loc, newTileTy, source, newOffsets);
223221
}
224-
else {
222+
else { // memref with dynamic shape or non memref source
225223
newInitTileOp = rewriter.create<xetile::InitTileOp>(
226224
loc, newTileTy, source, newOffsets, op.getMixedSizes(), op.getMixedStrides());
227225
}

test/Dialect/XeTile/Transforms/WgToSg/unit_tests.mlir

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -160,4 +160,12 @@ gpu.module @test_arith_extf {
160160
}
161161
gpu.return
162162
}
163+
164+
gpu.func @test_init_tile_using_addr(%src: i64, %dim0_size : index, %dim1_size : index,
165+
%dim0_stride : index, %dim1_stride : index ) {
166+
//CHECK: xetile.init_tile {{%.*}}[{{%.*}}, {{%.*}}], [{{%.*}}, {{%.*}}], [{{%.*}}, {{%.*}}] : i64 -> !xetile.tile<4x64xf16>
167+
%1 = xetile.init_tile %src[8, 16], [%dim0_size, %dim1_size], [%dim0_stride, %dim1_stride]
168+
: i64 -> !xetile.tile<128x64xf16, #xetile.tile_attr<wg_map = <sg_layout = [32, 1], sg_data = [4, 64]>>>
169+
gpu.return
170+
}
163171
}

0 commit comments

Comments
 (0)