Skip to content

Commit 7a63d93

Browse files
committed
address more feedback
1 parent 2d73e04 commit 7a63d93

File tree

8 files changed

+6
-139
lines changed

8 files changed

+6
-139
lines changed

mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -712,7 +712,7 @@ def XeGPU_MemLayoutAttr : XeGPUAttr<"MemLayout", "mem_layout"> {
712712
return getAttrs().contains(name);
713713
}
714714

715-
ArrayAttr getStrides() {
715+
ArrayAttr getStrideAttr() {
716716
return getAttrs().getAs<ArrayAttr>("stride");
717717
}
718718

mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td

Lines changed: 0 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -1392,41 +1392,4 @@ def XeGPU_StoreMatrixOp: XeGPU_Op<"store_matrix", [MemoryEffects<[MemWrite]>,
13921392
let hasVerifier = 1;
13931393
}
13941394

1395-
def XeGPU_MemDescSubviewOp: XeGPU_Op<"mem_desc_subview",
1396-
[Pure, ViewLikeOpInterface, AllElementTypesMatch<["src", "res"]>]> {
1397-
let description = [{
1398-
Creates a subview of a memory descriptor. The resulting memory descriptor can have
1399-
a lower rank than the source; in this case, the result dimensions correspond to the
1400-
higher-order dimensions of the source memory descriptor.
1401-
1402-
Arguments:
1403-
- `src` : a memory descriptor.
1404-
- `offsets` : the coordinates within the matrix the subview will be created from.
1405-
1406-
Results:
1407-
- `res` : a memory descriptor with smaller size.
1408-
1409-
}];
1410-
let arguments = (ins XeGPU_MemDesc:$src,
1411-
Variadic<Index>:$offsets,
1412-
DenseI64ArrayAttr:$const_offsets);
1413-
let results = (outs XeGPU_MemDesc:$res);
1414-
let assemblyFormat = [{$src `` custom<DynamicIndexList>($offsets, $const_offsets) prop-dict
1415-
attr-dict `` `:` qualified(type($src)) `->` qualified(type($res))}];
1416-
let builders = [
1417-
OpBuilder<(ins "Type": $res, "Value":$src, "llvm::ArrayRef<OpFoldResult>": $offsets)>
1418-
];
1419-
1420-
let extraClassDeclaration = [{
1421-
mlir::Value getViewSource() { return getSrc(); }
1422-
1423-
SmallVector<OpFoldResult> getMixedOffsets() {
1424-
return getMixedValues(getConstOffsets(), getOffsets(), getContext());
1425-
}
1426-
}];
1427-
1428-
let hasVerifier = 1;
1429-
}
1430-
1431-
14321395
#endif // MLIR_DIALECT_XEGPU_IR_XEGPUOPS_TD

mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -237,10 +237,10 @@ def XeGPU_MemDesc: XeGPUTypeDef<"MemDesc", "mem_desc", [ShapedTypeInterface], "m
237237
return MemDescType::get(getContext(), shape.value_or(getShape()), elementType, getMemLayout());
238238
}
239239

240-
ArrayAttr getStridesAttr() {
240+
ArrayAttr getStrideAttr() {
241241
auto layout = getMemLayout();
242242
if (layout && layout.hasAttr("stride")) {
243-
return layout.getStrides();
243+
return layout.getStrideAttr();
244244
}
245245
// derive and return default strides
246246
SmallVector<int64_t> defaultStrides;
@@ -262,7 +262,7 @@ def XeGPU_MemDesc: XeGPUTypeDef<"MemDesc", "mem_desc", [ShapedTypeInterface], "m
262262
/// Heuristic to determine if the MemDesc uses column-major layout,
263263
/// based on the rank and the value of the first stride dimension.
264264
bool isColMajor() {
265-
auto dim0 = dyn_cast<IntegerAttr>(getStridesAttr()[0]);
265+
auto dim0 = dyn_cast<IntegerAttr>(getStrideAttr()[0]);
266266
return getRank() == 2 && dim0 && dim0.getInt() == 1;
267267
}
268268

mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp

Lines changed: 1 addition & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -534,18 +534,6 @@ class CreateMemDescOpPattern final
534534
}
535535
};
536536

537-
class MemDescSubviewOpPattern final
538-
: public OpConversionPattern<xegpu::MemDescSubviewOp> {
539-
public:
540-
using OpConversionPattern<xegpu::MemDescSubviewOp>::OpConversionPattern;
541-
LogicalResult
542-
matchAndRewrite(xegpu::MemDescSubviewOp op, OpAdaptor adaptor,
543-
ConversionPatternRewriter &rewriter) const override {
544-
return rewriter.notifyMatchFailure(
545-
op, "MemDescSubviewOp are not supported on Xe2/Xe3 architecture.");
546-
}
547-
};
548-
549537
template <typename OpType,
550538
typename = std::enable_if_t<llvm::is_one_of<
551539
OpType, xegpu::LoadMatrixOp, xegpu::StoreMatrixOp>::value>>
@@ -1085,8 +1073,7 @@ void mlir::populateXeGPUToXeVMConversionPatterns(
10851073
typeConverter, patterns.getContext());
10861074
patterns.add<LoadStoreMatrixToXeVMPattern<xegpu::LoadMatrixOp>,
10871075
LoadStoreMatrixToXeVMPattern<xegpu::StoreMatrixOp>,
1088-
CreateMemDescOpPattern, MemDescSubviewOpPattern>(
1089-
typeConverter, patterns.getContext());
1076+
CreateMemDescOpPattern>(typeConverter, patterns.getContext());
10901077
patterns.add<FenceToXeVMPattern, DpasToXeVMPattern>(typeConverter,
10911078
patterns.getContext());
10921079
}

mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -781,7 +781,7 @@ SmallVector<int64_t> MemDescType::getStrideShape() {
781781

782782
SmallVector<int64_t> matrixShape(getShape().begin(), getShape().end());
783783

784-
ArrayAttr strideAttr = getStridesAttr();
784+
ArrayAttr strideAttr = getStrideAttr();
785785
SmallVector<int64_t> strides;
786786
for (Attribute attr : strideAttr.getValue()) {
787787
strides.push_back(cast<IntegerAttr>(attr).getInt());

mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp

Lines changed: 0 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -1135,40 +1135,6 @@ LogicalResult StoreMatrixOp::verify() {
11351135
[&]() { return emitError(); });
11361136
}
11371137

1138-
//===----------------------------------------------------------------------===//
1139-
// XeGPU_MemDescSubviewOp
1140-
//===----------------------------------------------------------------------===//
1141-
1142-
void MemDescSubviewOp::build(OpBuilder &builder, OperationState &state,
1143-
Type resTy, Value src,
1144-
llvm::ArrayRef<OpFoldResult> offsets) {
1145-
llvm::SmallVector<Value> dynamicOffsets;
1146-
llvm::SmallVector<int64_t> staticOffsets;
1147-
dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
1148-
auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets);
1149-
build(builder, state, resTy, src, dynamicOffsets, staticOffsetsAttr);
1150-
}
1151-
1152-
LogicalResult MemDescSubviewOp::verify() {
1153-
MemDescType srcTy = getSrc().getType();
1154-
MemDescType resTy = getRes().getType();
1155-
ArrayRef<int64_t> srcShape = srcTy.getShape();
1156-
ArrayRef<int64_t> resShape = resTy.getShape();
1157-
1158-
if (srcTy.getRank() < resTy.getRank())
1159-
return emitOpError("result rank must not exceed source rank.");
1160-
1161-
if (llvm::any_of(
1162-
llvm::zip_equal(resShape, srcShape.take_back(resShape.size())),
1163-
[](auto p) { return std::get<0>(p) > std::get<1>(p); }))
1164-
return emitOpError("result shape must not exceed source shape.");
1165-
1166-
if (srcTy.getStridesAttr() != resTy.getStridesAttr())
1167-
return emitOpError("result must inherit the source strides.");
1168-
1169-
return success();
1170-
}
1171-
11721138
} // namespace xegpu
11731139
} // namespace mlir
11741140

mlir/test/Dialect/XeGPU/invalid.mlir

Lines changed: 0 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -913,31 +913,3 @@ func.func @store_mem_desc_invalid_attr2(%arg0: !xegpu.mem_desc<16x64xf16>, %data
913913
return
914914
}
915915

916-
// -----
917-
func.func @mem_desc_subview_size_mismatch(%arg0: !xegpu.mem_desc<16x64xf16>) {
918-
// expected-error@+1 {{result shape must not exceed source shape}}
919-
%data = xegpu.mem_desc_subview %arg0[8, 8]: !xegpu.mem_desc<16x64xf16> -> !xegpu.mem_desc<32x16xf16>
920-
return
921-
}
922-
923-
// -----
924-
func.func @mem_desc_subview_layout_mismatch(%arg0: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride =[1, 16]>>) {
925-
// expected-error@+1 {{result must inherit the source strides}}
926-
%data = xegpu.mem_desc_subview %arg0[8, 8]: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride =[1, 16]>> -> !xegpu.mem_desc<8x16xf16>
927-
return
928-
}
929-
930-
// -----
931-
func.func @mem_desc_subview_element_type_mismatch(%arg0: !xegpu.mem_desc<16x64xf16>) {
932-
// expected-error@+1 {{failed to verify that all of {src, res} have same element type}}
933-
%data = xegpu.mem_desc_subview %arg0[8, 8]: !xegpu.mem_desc<16x64xf16> -> !xegpu.mem_desc<8x16xf32, #xegpu.mem_layout<stride =[64, 1]>>
934-
return
935-
}
936-
937-
// -----
938-
func.func @mem_desc_subview_rank_mismatch(%arg0: !xegpu.mem_desc<16x64xf16>) {
939-
// expected-error@+1 {{result rank must not exceed source rank}}
940-
%data = xegpu.mem_desc_subview %arg0[8, 8]: !xegpu.mem_desc<16x64xf16> -> !xegpu.mem_desc<4x8x16xf16>
941-
return
942-
}
943-

mlir/test/Dialect/XeGPU/ops.mlir

Lines changed: 0 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -895,25 +895,4 @@ gpu.func @simt_store_matrix_vector(%arg0: !xegpu.mem_desc<16x64xf16, #xegpu.mem_
895895
gpu.return
896896
}
897897

898-
// CHECK: gpu.func @mem_desc_subview([[ARG0:%.+]]: !xegpu.mem_desc<16x64xf16>)
899-
gpu.func @mem_desc_subview(%arg0: !xegpu.mem_desc<16x64xf16>) {
900-
//CHECK: xegpu.mem_desc_subview [[ARG0]][8, 8] : !xegpu.mem_desc<16x64xf16> -> !xegpu.mem_desc<8x16xf16, #xegpu.mem_layout<stride = [64, 1]>>
901-
%data = xegpu.mem_desc_subview %arg0[8, 8]: !xegpu.mem_desc<16x64xf16> -> !xegpu.mem_desc<8x16xf16, #xegpu.mem_layout<stride = [64, 1]>>
902-
gpu.return
903-
}
904-
905-
// CHECK: gpu.func @mem_desc_subview_lower_rank([[ARG0:%.+]]: !xegpu.mem_desc<16x64xf16>)
906-
gpu.func @mem_desc_subview_lower_rank(%arg0: !xegpu.mem_desc<16x64xf16>) {
907-
//CHECK: xegpu.mem_desc_subview [[ARG0]][8, 8] : !xegpu.mem_desc<16x64xf16> -> !xegpu.mem_desc<16xf16, #xegpu.mem_layout<stride = [64, 1]>>
908-
%data = xegpu.mem_desc_subview %arg0[8, 8]: !xegpu.mem_desc<16x64xf16> -> !xegpu.mem_desc<16xf16, #xegpu.mem_layout<stride = [64, 1]>>
909-
gpu.return
910-
}
911-
912-
// CHECK: gpu.func @mem_desc_subview_with_stride([[ARG0:%.+]]: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>>)
913-
gpu.func @mem_desc_subview_with_stride(%arg0: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>>) {
914-
//CHECK: xegpu.mem_desc_subview [[ARG0]][8, 8] : !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>> -> !xegpu.mem_desc<8x16xf16, #xegpu.mem_layout<stride = [1, 16]>>
915-
%data = xegpu.mem_desc_subview %arg0[8, 8]: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>> -> !xegpu.mem_desc<8x16xf16, #xegpu.mem_layout<stride = [1, 16]>>
916-
gpu.return
917-
}
918-
919898
}

0 commit comments

Comments
 (0)