Skip to content

Commit c9a4ecc

Browse files
authored
Clean up wg_data attribute and modify checks for innerblock (#834)
Clean up wg_data Attribute and modify checks for innerblock
1 parent 4d18adb commit c9a4ecc

File tree

11 files changed

+16
-38
lines changed

11 files changed

+16
-38
lines changed

include/imex/Dialect/XeTile/IR/XeTileAttrs.td

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,6 @@ def XeTile_TileAttr : XeTile_Attr<"XeTile", "tile_attr"> {
6464
OptionalParameter<"xetile::WorkGroupMapAttr">:$wg_map,
6565
DefaultValuedParameter<"mlir::DenseI32ArrayAttr", "mlir::DenseI32ArrayAttr::get($_ctxt, {1, 0})">:$order,
6666
OptionalParameter<"mlir::DenseI64ArrayAttr">:$inner_blocks,
67-
OptionalParameter<"mlir::DenseI32ArrayAttr">:$wg_data,
6867
OptionalParameter<"mlir::Attribute">:$memory_scope
6968
);
7069
let assemblyFormat = "`<` struct(params) `>`";
@@ -74,13 +73,11 @@ def XeTile_TileAttr : XeTile_Attr<"XeTile", "tile_attr"> {
7473
CArg<"xetile::WorkGroupMapAttr", "{}">:$wg_map,
7574
CArg<"llvm::ArrayRef<int32_t>", "{1, 0}">:$order,
7675
CArg<"llvm::ArrayRef<int64_t>", "{}">:$inner_blocks,
77-
CArg<"llvm::ArrayRef<int32_t>", "{}">:$wg_data,
7876
CArg<"int", "0">:$memory_scope),
7977
[{
8078
mlir::Type intType = mlir::IntegerType::get($_ctxt, 32);
8179
return $_get($_ctxt, sg_map, wg_map, mlir::DenseI32ArrayAttr::get($_ctxt, order),
8280
mlir::DenseI64ArrayAttr::get($_ctxt, inner_blocks),
83-
mlir::DenseI32ArrayAttr::get($_ctxt, wg_data),
8481
mlir::IntegerAttr::get(intType, memory_scope));
8582
}]>,
8683
AttrBuilder<(ins CArg<"llvm::ArrayRef<int32_t>", "{1, 0}">:$order,
@@ -90,19 +87,16 @@ def XeTile_TileAttr : XeTile_Attr<"XeTile", "tile_attr"> {
9087
return $_get($_ctxt, xetile::SubGroupMapAttr(), xetile::WorkGroupMapAttr(),
9188
mlir::DenseI32ArrayAttr::get($_ctxt, order),
9289
mlir::DenseI64ArrayAttr::get($_ctxt, {}),
93-
mlir::DenseI32ArrayAttr::get($_ctxt, {}),
9490
mlir::IntegerAttr::get(intType, memory_scope));
9591
}]>,
9692
AttrBuilder<(ins CArg<"xetile::SubGroupMapAttr", "{}">:$sg_map,
9793
CArg<"xetile::WorkGroupMapAttr", "{}">:$wg_map,
9894
CArg<"llvm::ArrayRef<int32_t>", "{1, 0}">:$order,
99-
CArg<"llvm::ArrayRef<int32_t>", "{}">:$wg_data,
10095
CArg<"int", "0">:$memory_scope),
10196
[{
10297
mlir::Type intType = mlir::IntegerType::get($_ctxt, 32);
10398
return $_get($_ctxt, sg_map, wg_map, mlir::DenseI32ArrayAttr::get($_ctxt, order),
10499
mlir::DenseI64ArrayAttr::get($_ctxt, {}),
105-
mlir::DenseI32ArrayAttr::get($_ctxt, wg_data),
106100
mlir::IntegerAttr::get(intType, memory_scope));
107101
}]>
108102
];

include/imex/Dialect/XeTile/IR/XeTileTypes.td

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -103,13 +103,6 @@ def XeTile : XeTile_Type<"Tile", "tile", [ShapedTypeInterface],
103103
return xetile::WorkGroupMapAttr();
104104
}
105105

106-
mlir::DenseI32ArrayAttr getWgData() {
107-
auto encoding = llvm::dyn_cast_if_present<xetile::XeTileAttr>(getEncoding());
108-
if (encoding)
109-
return encoding.getWgData();
110-
return mlir::DenseI32ArrayAttr();
111-
}
112-
113106
mlir::DenseI64ArrayAttr getInnerBlocks() {
114107
auto encoding = llvm::dyn_cast_if_present<xetile::XeTileAttr>(getEncoding());
115108
if (encoding)

lib/Dialect/XeTile/IR/XeTileDialect.cpp

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -116,14 +116,12 @@ mlir::LogicalResult XeTileAttr::verify(
116116
::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError,
117117
::imex::xetile::SubGroupMapAttr sg_map, xetile::WorkGroupMapAttr wg_map,
118118
mlir::DenseI32ArrayAttr order, mlir::DenseI64ArrayAttr inner_blocks,
119-
mlir::DenseI32ArrayAttr wg_data, mlir::Attribute memoryScope) {
119+
mlir::Attribute memoryScope) {
120120

121121
if (order != mlir::DenseI32ArrayAttr() && order.size() != 2)
122122
emitError() << "expect integer array of size 2 for order";
123-
if (inner_blocks != mlir::DenseI64ArrayAttr() && inner_blocks.size() != 2)
124-
emitError() << "expect integer array of size 2 for inner_blocks";
125-
if (wg_data != mlir::DenseI32ArrayAttr() && wg_data.size() != 2)
126-
emitError() << "expect integer array of size 2 for wg_data";
123+
if (inner_blocks != mlir::DenseI64ArrayAttr() && (inner_blocks.size() > 0 && inner_blocks.size() != 2))
124+
emitError() << "expect integer array of size 2 for non empty inner_blocks attribute";
127125
return mlir::success();
128126
}
129127

lib/Dialect/XeTile/IR/XeTileOps.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -386,7 +386,7 @@ mlir::LogicalResult LoadTileOp::verify() {
386386
if (!vecShape.equals(tileShape))
387387
return emitOpError("Output shape must match the tile shape.");
388388

389-
if (innerBlocks != mlir::DenseI64ArrayAttr()) {
389+
if (innerBlocks != mlir::DenseI64ArrayAttr() && innerBlocks.size() > 0) {
390390
// if inner_blocks is present in the tile_attr, the output of the load
391391
// must be 4D
392392
if (vecShape.size() != 4)
@@ -421,7 +421,7 @@ mlir::LogicalResult StoreTileOp::verify() {
421421
return emitOpError(
422422
"value must be a 2D vector if inner_blocks is not used in tile_attr.");
423423

424-
if (innerBlocks != mlir::DenseI32ArrayAttr()) {
424+
if (innerBlocks != mlir::DenseI32ArrayAttr() && innerBlocks.size() > 0) {
425425
auto vecShape = getValue().getType().getShape();
426426
// if inner_blocks is present in the tile_attr, the stored value
427427
// must be 4D

lib/Dialect/XeTile/Transforms/BlockAligning.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -205,8 +205,7 @@ struct InitTileOpPattern
205205

206206
auto attr = imex::xetile::XeTileAttr::get(
207207
op.getContext(), tileTy.getSgMap(), tileTy.getWgMap(),
208-
tileTy.getOrder(), newBlockSize, tileTy.getWgData(),
209-
tileTy.getMemoryScope());
208+
tileTy.getOrder(), newBlockSize, tileTy.getMemoryScope());
210209

211210
auto newTileTy = imex::xetile::TileType::get(tileTy.getShape(),
212211
tileTy.getElementType(), attr);

lib/Dialect/XeTile/Transforms/Blocking.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -848,8 +848,7 @@ struct InitTileOpPattern
848848

849849
auto attr = imex::xetile::XeTileAttr::get(
850850
op.getContext(), tileTy.getSgMap(), tileTy.getWgMap(),
851-
tileTy.getOrder(), innerBlocks, tileTy.getWgData(),
852-
tileTy.getMemoryScope());
851+
tileTy.getOrder(), innerBlocks, tileTy.getMemoryScope());
853852

854853
auto newTileTy =
855854
imex::xetile::TileType::get(tileTy.getShape(), elemTy, attr);

lib/Dialect/XeTile/Transforms/Canonicalization.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -380,8 +380,7 @@ struct XeTileCanonicalizationPass final
380380
auto newAttr = imex::xetile::XeTileAttr::get(
381381
tileTy.getContext(), tileTy.getSgMap(), tileTy.getWgMap(),
382382
mlir::DenseI32ArrayAttr::get(tileTy.getContext(), {1, 0}),
383-
tileTy.getInnerBlocks(), tileTy.getWgData(),
384-
tileTy.getMemoryScope());
383+
tileTy.getInnerBlocks(), tileTy.getMemoryScope());
385384

386385
return imex::xetile::TileType::get(
387386
swapLastTwoElems(tileTy.getShape()), tileTy.getElementType(),

lib/Dialect/XeTile/Transforms/OptimizeTranspose.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,7 @@ struct InitTileOpPattern final
114114
: mlir::DenseI32ArrayAttr::get(getContext(), {1, 0});
115115
auto newTileAttr = imex::xetile::XeTileAttr::get(
116116
getContext(), tileTy.getSgMap(), tileTy.getWgMap(), orderAttr,
117-
tileTy.getInnerBlocks(), tileTy.getWgData(), tileTy.getMemoryScope());
117+
tileTy.getInnerBlocks(), tileTy.getMemoryScope());
118118
auto transposedTileTy = imex::xetile::TileType::get(
119119
imex::swapLastTwoElements(initOp.getType().getShape()),
120120
initOp.getElementType(), newTileAttr);

test/Dialect/XeTile/IR/invalid.mlir

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -186,12 +186,10 @@ func.func @test_init_tile_with_mismatch_memory_space(%a: memref<1024x1024xf16, 3
186186
#wg_map_1 = #xetile.wg_map<sg_layout = [4], sg_data = [32, 128]>
187187
// expected-error@+1 {{expect integer array of size 2 for sg_data}}
188188
#wg_map_2 = #xetile.wg_map<sg_layout = [2, 2], sg_data = [32, 128, 32]>
189-
// expected-error@+1 {{expect integer array of size 2 for inner_blocks}}
189+
// expected-error@+1 {{expect integer array of size 2 for non empty inner_blocks attribute}}
190190
#wg_map_3 = #xetile.tile_attr<inner_blocks = [8, 16, 8]>
191191
// expected-error@+1 {{expect integer array of size 2 for order}}
192192
#wg_map_4 = #xetile.tile_attr<order = [0, 1, 2]>
193-
// expected-error@+1 {{expect integer array of size 2 for wg_data}}
194-
#wg_map_5 = #xetile.tile_attr<wg_data = [32, 64, 128]>
195193

196194

197195
// -----

test/Dialect/XeTile/IR/ops.mlir

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
#tile_attr = #xetile.tile_attr<wg_map = #wg_map, sg_map = #sg_map>
1010
#tile_attr_w_inner_blocks = #xetile.tile_attr<inner_blocks = [8, 16]>
1111
#tile_attr_w_order = #xetile.tile_attr<order = [0, 1]>
12-
#tile_attr_w_wg_data = #xetile.tile_attr<wg_map = #wg_map, wg_data = [128, 128]>
1312

1413

1514
#wg_map_mma_a = #xetile.wg_map<sg_layout = [8, 4], sg_data = [32, 32]>
@@ -199,16 +198,15 @@ func.func @test_store_tile(%value1 : vector<64x32xf16>,
199198
}
200199

201200
// CHECK-LABEL: func @test_prefetch_tile({{.*}}) {
202-
func.func @test_prefetch_tile(%src: !xetile.tile<64x64xf16>, %src1: !xetile.tile<128x128xf16, #tile_attr_w_wg_data>) {
201+
func.func @test_prefetch_tile(%src: !xetile.tile<64x64xf16>, %src1: !xetile.tile<128x128xf16>) {
203202

204203
// CHECK: xetile.prefetch_tile
205204
// CHECK-SAME: !xetile.tile<64x64xf16>
206205
xetile.prefetch_tile %src : !xetile.tile<64x64xf16>
207206

208207
// CHECK: xetile.prefetch_tile
209-
// CHECK-SAME: !xetile.tile<128x128xf16, #xetile.tile_attr<wg_map = <sg_layout = [2, 2],
210-
// CHECK-SAME: sg_data = [32, 128]>, wg_data = [128, 128]>>
211-
xetile.prefetch_tile %src1 : !xetile.tile<128x128xf16, #tile_attr_w_wg_data>
208+
// CHECK-SAME: !xetile.tile<128x128xf16>
209+
xetile.prefetch_tile %src1 : !xetile.tile<128x128xf16>
212210

213211
return
214212
}

0 commit comments

Comments
 (0)