Skip to content

Commit 915830c

Browse files
committed
addresses feedbacks
1 parent 3bb754b commit 915830c

File tree

2 files changed

+3
-10
lines changed

2 files changed

+3
-10
lines changed

mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -303,8 +303,7 @@ void XeGPUBlockingPass::runOnOperation() {
303303
// If the encoding is a ScatterTensorDescAttr, we need to
304304
// potentially adjust the chunk size based on the inst_data.
305305
if (tdescTy.isScattered()) {
306-
auto scatterAttr = tdescTy.getEncodingAsScatterTensorDescAttr();
307-
// mlir::dyn_cast<xegpu::ScatterTensorDescAttr>(encoding);
306+
auto scatterAttr = llvm::dyn_cast_if_present<xegpu::ScatterTensorDescAttr>(encoding);
308307
int64_t chunkSize = scatterAttr.getChunkSize().getInt();
309308

310309
if (chunkSize > 1) {
@@ -313,12 +312,9 @@ void XeGPUBlockingPass::runOnOperation() {
313312
if (!instData.empty())
314313
blockedChunkSize = instData.asArrayRef().back();
315314

316-
auto chunkSizeAttr = mlir::IntegerAttr::get(
317-
mlir::IntegerType::get(ctx, 64), blockedChunkSize);
318-
319315
// To create a new attribute with a different chunk_size:
320316
auto newEncoding = xegpu::ScatterTensorDescAttr::get(
321-
ctx, scatterAttr.getMemorySpace(), chunkSizeAttr);
317+
ctx, scatterAttr.getMemorySpace().getValue(), blockedChunkSize);
322318

323319
encoding = newEncoding;
324320
}

mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -116,12 +116,9 @@ struct TestXeGPUUnrollingPatterns
116116
if (!instData.empty())
117117
blockedChunkSize = instData.asArrayRef().back();
118118

119-
auto chunkSizeAttr = mlir::IntegerAttr::get(
120-
mlir::IntegerType::get(ctx, 64), blockedChunkSize);
121-
122119
// To create a new attribute with a different chunk_size:
123120
auto newEncoding = xegpu::ScatterTensorDescAttr::get(
124-
ctx, scatterAttr.getMemorySpace(), chunkSizeAttr);
121+
ctx, scatterAttr.getMemorySpace().getValue(), blockedChunkSize);
125122

126123
encoding = newEncoding;
127124
}

0 commit comments

Comments
 (0)