Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
Original file line number Diff line number Diff line change
Expand Up @@ -712,10 +712,14 @@ def XeGPU_MemLayoutAttr : XeGPUAttr<"MemLayout", "mem_layout"> {
return getAttrs().contains(name);
}

ArrayAttr getStrides() {
ArrayAttr getStrideAttr() {
return getAttrs().getAs<ArrayAttr>("stride");
}

/// Returns the array stored under the "block" key of this layout's
/// attribute dictionary. The lookup result is returned as-is.
ArrayAttr getBlockAttr() {
  auto attrDict = getAttrs();
  return attrDict.getAs<ArrayAttr>("block");
}

}];

}
Expand Down
65 changes: 20 additions & 45 deletions mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -1298,14 +1298,14 @@ def XeGPU_CreateMemDescOp: XeGPU_Op<"create_mem_desc", [Pure,
}

def XeGPU_LoadMatrixOp: XeGPU_Op<"load_matrix", [MemoryEffects<[MemRead]>,
AllElementTypesMatch<["mem_desc", "res"]>,
AllRanksMatch<["mem_desc", "res"]>]> {
AllElementTypesMatch<["mem_desc", "res"]>]> {
let arguments = (ins XeGPU_MemDesc:$mem_desc,
Variadic<Index>: $offsets,
DenseI64ArrayAttr: $const_offsets,
OptionalAttr<UnitAttr>:$subgroup_block_io,
OptionalAttr<DistributeLayoutAttr>:$layout
);
let results = (outs XeGPU_ValueType:$res);
let results = (outs AnyTypeOf<[XeGPU_ValueType, XeGPU_ScalarType]>:$res);
let assemblyFormat = [{
$mem_desc `` custom<DynamicIndexList>($offsets, $const_offsets)
prop-dict attr-dict `` `:` type(operands) `->` type(results)
Expand All @@ -1319,6 +1319,9 @@ def XeGPU_LoadMatrixOp: XeGPU_Op<"load_matrix", [MemoryEffects<[MemRead]>,
Arguments:
- `mem_desc`: the memory descriptor identifying the SLM region.
- `offsets`: the coordinates within the matrix to read from.
- `subgroup_block_io`: [optional] An attribute indicating that the operation can be
lowered to a subgroup block load. When this attribute is present,
the offsets are subgroup-uniform across all lanes.
- `layout`: [optional] An attribute for guiding distributions among
subgroups and/or work-items. It currently can accept either
LayoutAttr or SliceAttr.
Expand All @@ -1336,21 +1339,24 @@ def XeGPU_LoadMatrixOp: XeGPU_Op<"load_matrix", [MemoryEffects<[MemRead]>,
}

ArrayRef<int64_t> getDataShape() {
return getRes().getType().getShape();
auto resTy = getRes().getType();
if (auto vecTy = llvm::dyn_cast<VectorType>(resTy))
return vecTy.getShape();
return {};
}
}];

let hasVerifier = 1;
}

def XeGPU_StoreMatrixOp: XeGPU_Op<"store_matrix", [MemoryEffects<[MemWrite]>,
AllElementTypesMatch<["mem_desc", "data"]>,
AllRanksMatch<["mem_desc", "data"]>]> {
AllElementTypesMatch<["mem_desc", "data"]>]> {
let arguments = (ins
XeGPU_ValueType:$data,
AnyTypeOf<[XeGPU_ValueType, XeGPU_ScalarType]>:$data,
XeGPU_MemDesc:$mem_desc,
Variadic<Index>: $offsets,
DenseI64ArrayAttr: $const_offsets,
OptionalAttr<UnitAttr>:$subgroup_block_io,
OptionalAttr<DistributeLayoutAttr>:$layout
);
let assemblyFormat = [{ $data `,` $mem_desc `` custom<DynamicIndexList>($offsets, $const_offsets)
Expand All @@ -1364,6 +1370,9 @@ def XeGPU_StoreMatrixOp: XeGPU_Op<"store_matrix", [MemoryEffects<[MemWrite]>,
- `mem_desc`: the memory descriptor specifying the SLM region.
- `offsets`: the coordinates within the matrix where the data will be written.
- `data`: the values to be stored in the matrix.
- `subgroup_block_io`: [optional] An attribute indicating that the operation can be
lowered to a subgroup block store. When this attribute is present,
the offsets are subgroup-uniform across all lanes.
- `layout`: [optional] An attribute for guiding distributions among
subgroups and/or work-items. It currently can accept either
LayoutAttr or SliceAttr.
Expand All @@ -1378,49 +1387,15 @@ def XeGPU_StoreMatrixOp: XeGPU_Op<"store_matrix", [MemoryEffects<[MemWrite]>,
}

ArrayRef<int64_t> getDataShape() {
return getData().getType().getShape();
auto DataTy = getData().getType();
if (auto vecTy = llvm::dyn_cast<VectorType>(DataTy))
return vecTy.getShape();
return {};
}

}];

let hasVerifier = 1;
}

// Definition of the `mem_desc_subview` op. It is declared Pure and
// view-like (ViewLikeOpInterface), and the element types of `src` and `res`
// are required to match.
def XeGPU_MemDescSubviewOp: XeGPU_Op<"mem_desc_subview",
[Pure, ViewLikeOpInterface, AllElementTypesMatch<["src", "res"]>]> {
let description = [{
Creates a subview of a memory descriptor. The resulting memory descriptor can have
a lower rank than the source; in this case, the result dimensions correspond to the
higher-order dimensions of the source memory descriptor.

Arguments:
- `src` : a memory descriptor.
- `offsets` : the coordinates within the matrix the subview will be created from.

Results:
- `res` : a memory descriptor with smaller size.

}];
// Offsets are carried by two arguments: index operands (`offsets`) and an
// i64 array attribute (`const_offsets`).
// NOTE(review): presumably `const_offsets` holds the static entries with a
// sentinel for dynamic ones, as is usual for DynamicIndexList - confirm.
let arguments = (ins XeGPU_MemDesc:$src,
Variadic<Index>:$offsets,
DenseI64ArrayAttr:$const_offsets);
let results = (outs XeGPU_MemDesc:$res);
// Both offset arguments are printed/parsed together by the custom
// DynamicIndexList directive.
let assemblyFormat = [{$src `` custom<DynamicIndexList>($offsets, $const_offsets) prop-dict
attr-dict `` `:` qualified(type($src)) `->` qualified(type($res))}];
// Convenience builder taking the offsets as a single list of OpFoldResult.
let builders = [
OpBuilder<(ins "Type": $res, "Value":$src, "llvm::ArrayRef<OpFoldResult>": $offsets)>
];

let extraClassDeclaration = [{
// ViewLikeOpInterface hook: the viewed value is the `src` operand.
mlir::Value getViewSource() { return getSrc(); }

// Returns the offsets as one OpFoldResult list, built by getMixedValues
// from `const_offsets` and the `offsets` operands.
SmallVector<OpFoldResult> getMixedOffsets() {
return getMixedValues(getConstOffsets(), getOffsets(), getContext());
}
}];

// A custom verifier is declared for this op; its definition is not part of
// this file.
let hasVerifier = 1;
}


#endif // MLIR_DIALECT_XEGPU_IR_XEGPUOPS_TD
62 changes: 59 additions & 3 deletions mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
Original file line number Diff line number Diff line change
Expand Up @@ -237,19 +237,75 @@ def XeGPU_MemDesc: XeGPUTypeDef<"MemDesc", "mem_desc", [ShapedTypeInterface], "m
return MemDescType::get(getContext(), shape.value_or(getShape()), elementType, getMemLayout());
}

ArrayAttr getStrides() {
ArrayAttr getStrideAttr() {
auto layout = getMemLayout();
if (layout && layout.hasAttr("stride")) {
return layout.getStrides();
return layout.getStrideAttr();
}

// derive and return default strides
SmallVector<int64_t> defaultStrides;
llvm::append_range(defaultStrides, getShape().drop_front());
llvm::append_values(defaultStrides, 1);
Builder builder(getContext());
return builder.getI64ArrayAttr(defaultStrides);
}

/// Returns the "block" attribute of the memory layout when a layout is
/// present and carries one; otherwise returns an empty i64 array attribute.
ArrayAttr getBlockAttr() {
  if (auto memLayout = getMemLayout(); memLayout && memLayout.hasAttr("block"))
    return memLayout.getBlockAttr();

  // No layout, or a layout without blocking: report "no blocking" as an
  // empty array rather than a null attribute.
  return Builder(getContext()).getI64ArrayAttr({});
}

/// Heuristic to determine if the MemDesc uses column-major layout,
/// based on the rank and the value of the first stride dimension.
///
/// Returns true only for rank-2 descriptors whose first stride is the
/// integer 1; every other case yields false.
bool isColMajor() {
  // Check the rank first so the stride array is never indexed for
  // descriptors that cannot be column-major matrices anyway.
  if (getRank() != 2)
    return false;

  // dyn_cast yields a null attribute when the element is not an
  // IntegerAttr; treat that as "not column-major" instead of
  // dereferencing a null attribute.
  auto dim0 = dyn_cast<IntegerAttr>(getStrideAttr()[0]);
  return dim0 && dim0.getInt() == 1;
}

// Returns the blocking shape of this MemDescType. The blocking shape is
// taken from the block attribute when that attribute is non-empty;
// otherwise it defaults to the full shape of the memory descriptor.
SmallVector<int64_t> getBlockShape() {
  ArrayAttr blockAttr = getBlockAttr();
  if (blockAttr.empty())
    return SmallVector<int64_t>(getShape());

  SmallVector<int64_t> blockShape;
  for (auto dimAttr : blockAttr.getValue())
    blockShape.push_back(cast<IntegerAttr>(dimAttr).getInt());
  return blockShape;
}

// Get the strides as a vector of integers.
// If the descriptor has a block attribute, the strides are blocked strides.
//
// The blocking is applied to the base matrix shape derived from the
// memory descriptor's stride information. If the matrix described by
// the memory descriptor is not contiguous, it is assumed that the base
// matrix is contiguous and follows the same memory layout.
//
// The computation proceeds as follows:
//   1. compute the original matrix shape from the stride info,
//   2. compute the number of blocks in each dimension of that shape,
//   3. compute the outer block shape and strides,
//   4. combine the inner and outer block shapes and strides.
// e.g. for `mem_desc<32x256xf16, @block=[16, 8], @stride=[1, 32]>`
//      its memory layout tuple is ([2,32,16,8],[128,256,1,16]);
//      for `mem_desc<256x32xf16, @block=[8, 16]>` with default @stride=[32, 1]
//      its memory layout tuple is ([32,2,8,16],[256,128,16,1]).
SmallVector<int64_t> getStrideShape();

/// Generates instructions to compute the linearized offset for `offsets`.
/// If the memory descriptor is blocked, the linearized offset is computed
/// based on the blocked layout. The strides of the memory descriptor are
/// always taken into account, whether or not it is blocked.
Value getLinearOffsets(OpBuilder &builder,
Location loc, ArrayRef<OpFoldResult> offsets);


}];

let hasCustomAssemblyFormat = true;
Expand Down
1 change: 1 addition & 0 deletions mlir/lib/Conversion/XeGPUToXeVM/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ add_mlir_conversion_library(MLIRXeGPUToXeVM
MLIRIndexDialect
MLIRSCFDialect
MLIRXeGPUDialect
MLIRXeGPUUtils
MLIRPass
MLIRTransforms
MLIRSCFTransforms
Expand Down
Loading