Skip to content

Commit 281393e

Browse files
committed
address feedback
1 parent 158edd6 commit 281393e

File tree

3 files changed

+13
-8
lines changed

3 files changed

+13
-8
lines changed

mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1282,12 +1282,6 @@ def XeGPU_ConvertLayoutOp: XeGPU_Op<"convert_layout", [Pure, AllTypesMatch<["sou
12821282
let hasCanonicalizer = 1;
12831283
}
12841284

1285-
def isSharedPred : CPred<"isSharedMemory(llvm::cast<mlir::MemRefType>($_self))">;
1286-
class StaticShared1DMemRefOf<list<Type> allowedTypes> :
1287-
ConfinedType<MemRefRankOf<allowedTypes, [1]>, [HasStaticShapePred, isSharedPred],
1288-
"statically shaped " # MemRefOf<allowedTypes>.summary # " for shared memory",
1289-
"mlir::MemRefType">;
1290-
12911285
class SizeInBits<string name> :
12921286
StrFunc<"llvm::cast<mlir::ShapedType>($" # name # ".getType()).getNumElements()"
12931287
"*llvm::cast<mlir::ShapedType>($" # name # ".getType()).getElementTypeBitWidth()">;
@@ -1308,7 +1302,7 @@ def XeGPU_CreateMemDescOp: XeGPU_Op<"create_mem_desc", [Pure,
13081302
Results:
13091303
- `mem_desc` : the memory descriptor.
13101304
}];
1311-
let arguments = (ins AnyTypeOf<[StaticShared1DMemRefOf<[I8]>, ConfinedType<MemRefRankOf<[XeGPU_ScalarType], [2]>, [HasStaticShapePred, isSharedPred]>]>:$source);
1305+
let arguments = (ins AnyTypeOf<[StaticShared1DMemRefOf<[XeGPU_ScalarType]>, StaticShared2DMemRefOf<[XeGPU_ScalarType]>]>:$source);
13121306
let results = (outs XeGPU_MemDesc:$mem_desc);
13131307
let assemblyFormat = "$source prop-dict attr-dict `` `:` type($source) `->` qualified(type($mem_desc))";
13141308
}

mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,17 @@ class XeGPUTypeDef<string name, string typeMnemonic, list<Trait> traits = [],
3535
let mnemonic = typeMnemonic;
3636
}
3737

38+
def isSharedPred : CPred<"isSharedMemory(llvm::cast<mlir::MemRefType>($_self))">;
39+
class StaticShared1DMemRefOf<list<Type> allowedTypes> :
40+
ConfinedType<MemRefRankOf<allowedTypes, [1]>, [HasStaticShapePred, isSharedPred],
41+
"reside in share memory and statically 1d shaped " # MemRefOf<allowedTypes>.summary # " ",
42+
"mlir::MemRefType">;
43+
44+
class StaticShared2DMemRefOf<list<Type> allowedTypes>:
45+
ConfinedType<MemRefRankOf<allowedTypes, [2]>, [HasStaticShapePred, isSharedPred],
46+
"reside in share memory and statically 2d shaped " # MemRefOf<allowedTypes>.summary # " ",
47+
"mlir::MemRefType">;
48+
3849
def XeGPU_TensorDesc: XeGPUTypeDef<"TensorDesc", "tensor_desc",
3950
[ShapedTypeInterface], "::mlir::TensorType"> {
4051
let summary = "TensorDesc describing regions of interested data.";

mlir/test/Dialect/XeGPU/invalid.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -836,7 +836,7 @@ func.func @slice_attr_repeat_dim() {
836836
// -----
837837
func.func @create_mem_desc_non_slm() {
838838
%m = memref.alloca() {alignment = 1024} : memref<2048xi8, 1>
839-
// expected-error@+1 {{operand #0 must be statically shaped memref of 8-bit signless integer values for shared memory}}
839+
// expected-error@+1 {{operand #0 must be reside in share memory and statically 1d shaped memref }}
840840
%mem_desc = xegpu.create_mem_desc %m : memref<2048xi8, 1> -> !xegpu.mem_desc<16x64xf16>
841841
return
842842
}

0 commit comments

Comments
 (0)