Skip to content
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
Original file line number Diff line number Diff line change
Expand Up @@ -716,8 +716,30 @@ def XeGPU_MemLayoutAttr : XeGPUAttr<"MemLayout", "mem_layout"> {
return getAttrs().getAs<ArrayAttr>("stride");
}

ArrayAttr getBlockAttr() {
return getAttrs().getAs<ArrayAttr>("block");
}

}];

}

def RowOriented : I32EnumAttrCase<"ROW", 0, "row">;
def ColOriented : I32EnumAttrCase<"COL", 1, "col">;
def MatrixAccessDirection :
I32EnumAttr<"MatrixAccessDirection",
"Matrix elements/vectors can have row or column direction", [
RowOriented, ColOriented
]> {
let genSpecializedAttr = 0;
let cppNamespace = "::mlir::xegpu";
}
def MatrixAccessDirectionAttr :
EnumAttr<XeGPU_Dialect,
MatrixAccessDirection,
"matrix_access_direction">{
let summary = [{Describe the direction of memory access for load_matrix and store_matrix.}];
let assemblyFormat = "`<` $value `>`";
}

#endif // MLIR_DIALECT_XEGPU_IR_XEGPUATTRS_TD
26 changes: 18 additions & 8 deletions mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -1298,14 +1298,16 @@ def XeGPU_CreateMemDescOp: XeGPU_Op<"create_mem_desc", [Pure,
}

def XeGPU_LoadMatrixOp: XeGPU_Op<"load_matrix", [MemoryEffects<[MemRead]>,
AllElementTypesMatch<["mem_desc", "res"]>,
AllRanksMatch<["mem_desc", "res"]>]> {
AllElementTypesMatch<["mem_desc", "res"]>]> {
let arguments = (ins XeGPU_MemDesc:$mem_desc,
Variadic<Index>: $offsets,
DenseI64ArrayAttr: $const_offsets,
OptionalAttr<I32Attr>:$vec_length,
OptionalAttr<MatrixAccessDirectionAttr>:$vec_direction,
OptionalAttr<UnitAttr>:$subgroup_block_io,
OptionalAttr<DistributeLayoutAttr>:$layout
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

please update the descrioption of the op with the meaning of block_io

);
let results = (outs XeGPU_ValueType:$res);
let results = (outs AnyTypeOf<[XeGPU_ValueType, XeGPU_ScalarType]>:$res);
let assemblyFormat = [{
$mem_desc `` custom<DynamicIndexList>($offsets, $const_offsets)
prop-dict attr-dict `` `:` type(operands) `->` type(results)
Expand Down Expand Up @@ -1336,21 +1338,26 @@ def XeGPU_LoadMatrixOp: XeGPU_Op<"load_matrix", [MemoryEffects<[MemRead]>,
}

ArrayRef<int64_t> getDataShape() {
return getRes().getType().getShape();
auto resTy = getRes().getType();
if (auto vecTy = llvm::dyn_cast<VectorType>(resTy))
return vecTy.getShape();
return {};
}
}];

let hasVerifier = 1;
}

def XeGPU_StoreMatrixOp: XeGPU_Op<"store_matrix", [MemoryEffects<[MemWrite]>,
AllElementTypesMatch<["mem_desc", "data"]>,
AllRanksMatch<["mem_desc", "data"]>]> {
AllElementTypesMatch<["mem_desc", "data"]>]> {
let arguments = (ins
XeGPU_ValueType:$data,
AnyTypeOf<[XeGPU_ValueType, XeGPU_ScalarType]>:$data,
XeGPU_MemDesc:$mem_desc,
Variadic<Index>: $offsets,
DenseI64ArrayAttr: $const_offsets,
OptionalAttr<I32Attr>:$vec_length,
OptionalAttr<MatrixAccessDirectionAttr>:$vec_direction,
OptionalAttr<UnitAttr>:$subgroup_block_io,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

update description.

OptionalAttr<DistributeLayoutAttr>:$layout
);
let assemblyFormat = [{ $data `,` $mem_desc `` custom<DynamicIndexList>($offsets, $const_offsets)
Expand Down Expand Up @@ -1378,7 +1385,10 @@ def XeGPU_StoreMatrixOp: XeGPU_Op<"store_matrix", [MemoryEffects<[MemWrite]>,
}

ArrayRef<int64_t> getDataShape() {
return getData().getType().getShape();
auto DataTy = getData().getType();
if (auto vecTy = llvm::dyn_cast<VectorType>(DataTy))
return vecTy.getShape();
return {};
}

}];
Expand Down
50 changes: 49 additions & 1 deletion mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,7 @@ def XeGPU_MemDesc: XeGPUTypeDef<"MemDesc", "mem_desc", [ShapedTypeInterface], "m
return MemDescType::get(getContext(), shape.value_or(getShape()), elementType, getMemLayout());
}

ArrayAttr getStrides() {
ArrayAttr getStridesAttr() {
auto layout = getMemLayout();
if (layout && layout.hasAttr("stride")) {
return layout.getStrides();
Expand All @@ -250,6 +250,54 @@ def XeGPU_MemDesc: XeGPUTypeDef<"MemDesc", "mem_desc", [ShapedTypeInterface], "m
Builder builder(getContext());
return builder.getI64ArrayAttr(defaultStrides);
}

/// Heuristic to determine if the MemDesc uses column-major layout,
/// based on the rank and the value of the first stride dimension.
bool isColMajor() {
auto dim0 = dyn_cast<IntegerAttr>(getStridesAttr()[0]);
return getRank() == 2 && dim0 && dim0.getInt() == 1;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why check dim0 exists? it should always exist right?

}

// get the Blocking shape for a MemDescType, Which is represented
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: Capitalize the first letter of all comment sentences per coding standards.

// as an attribute in MemDescType. By default it is the shape
// of the mdescTy
SmallVector<int64_t> getBlockSize() {
SmallVector<int64_t> size(getShape());
MemLayoutAttr layout = getMemLayout();
if (layout && layout.hasAttr("block")) {
ArrayAttr attr = layout.getBlockAttr();
size.clear();
llvm::for_each(attr, [&](Attribute elem) {
if (auto intElem = dyn_cast<IntegerAttr>(elem))
size.push_back(intElem.getInt());
});
}
return size;
}

// Get strides as vector of integer.
// If it contains block attribute, the strides are blocked strides.
//
// The blocking is applied against the original matrix shape
// so that the linear offset is not impacted by the subview.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what is the subview you refer here? is it the subview to specific block?

//
// It first computes the original matrix shape using the stride info,
// then computes the number of blocks in each dimension of original shape,
// then compute the outer block shape and stride,
// then combines the inner and outer block shape and stride
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: use code quotes for (mem_desc) for code examples. That way doxygen will generate more readable docs.

// e.g. for mem_desc<32x256xf16, @block=[16, 8], @strides=[1, 32]>
// its memory layout tuple is ([2,32,16,8],[128,256,1,16])
// for mem_desc<256x32xf16, @block=[8, 16]> with default @stride[32, 1]
// its memory layout tuple is ([32,2,8,16],[256,128,16,1])
SmallVector<int64_t> getStrides();

/// Generates instructions to compute the linearize offset
// if the memory descriptor is blocked, it returns linearize offset based on the blocked layout
// the strides of memory descriptor is always considered regardless of blocked or not
Value getLinearOffsets(OpBuilder &builder,
Location loc, ArrayRef<OpFoldResult> offsets);


}];

let hasCustomAssemblyFormat = true;
Expand Down
1 change: 1 addition & 0 deletions mlir/lib/Conversion/XeGPUToXeVM/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ add_mlir_conversion_library(MLIRXeGPUToXeVM
MLIRIndexDialect
MLIRSCFDialect
MLIRXeGPUDialect
MLIRXeGPUUtils
MLIRPass
MLIRTransforms
MLIRSCFTransforms
Expand Down
Loading
Loading