Skip to content

Commit 446b951

Browse files
committed
add tests and refactoring
1 parent 554b95e commit 446b951

File tree

10 files changed

+211
-77
lines changed

10 files changed

+211
-77
lines changed

mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1304,10 +1304,10 @@ def XeGPU_LoadMatrixOp: XeGPU_Op<"load_matrix", [MemoryEffects<[MemRead]>,
13041304
DenseI64ArrayAttr: $const_offsets,
13051305
OptionalAttr<I32Attr>:$vec_length,
13061306
OptionalAttr<MatrixAccessDirectionAttr>:$vec_direction,
1307-
OptionalAttr<UnitAttr>:$subgroupBlockIO,
1307+
OptionalAttr<UnitAttr>:$subgroup_block_io,
13081308
OptionalAttr<DistributeLayoutAttr>:$layout
13091309
);
1310-
let results = (outs XeGPU_ValueType:$res);
1310+
let results = (outs AnyTypeOf<[XeGPU_ValueType, XeGPU_ScalarType]>:$res);
13111311
let assemblyFormat = [{
13121312
$mem_desc `` custom<DynamicIndexList>($offsets, $const_offsets)
13131313
prop-dict attr-dict `` `:` type(operands) `->` type(results)
@@ -1338,7 +1338,10 @@ def XeGPU_LoadMatrixOp: XeGPU_Op<"load_matrix", [MemoryEffects<[MemRead]>,
13381338
}
13391339

13401340
ArrayRef<int64_t> getDataShape() {
1341-
return getRes().getType().getShape();
1341+
auto resTy = getRes().getType();
1342+
if (auto vecTy = llvm::dyn_cast<VectorType>(resTy))
1343+
return vecTy.getShape();
1344+
return {};
13421345
}
13431346
}];
13441347

@@ -1348,10 +1351,13 @@ def XeGPU_LoadMatrixOp: XeGPU_Op<"load_matrix", [MemoryEffects<[MemRead]>,
13481351
def XeGPU_StoreMatrixOp: XeGPU_Op<"store_matrix", [MemoryEffects<[MemWrite]>,
13491352
AllElementTypesMatch<["mem_desc", "data"]>]> {
13501353
let arguments = (ins
1351-
XeGPU_ValueType:$data,
1354+
AnyTypeOf<[XeGPU_ValueType, XeGPU_ScalarType]>:$data,
13521355
XeGPU_MemDesc:$mem_desc,
13531356
Variadic<Index>: $offsets,
13541357
DenseI64ArrayAttr: $const_offsets,
1358+
OptionalAttr<I32Attr>:$vec_length,
1359+
OptionalAttr<MatrixAccessDirectionAttr>:$vec_direction,
1360+
OptionalAttr<UnitAttr>:$subgroup_block_io,
13551361
OptionalAttr<DistributeLayoutAttr>:$layout
13561362
);
13571363
let assemblyFormat = [{ $data `,` $mem_desc `` custom<DynamicIndexList>($offsets, $const_offsets)
@@ -1379,7 +1385,10 @@ def XeGPU_StoreMatrixOp: XeGPU_Op<"store_matrix", [MemoryEffects<[MemWrite]>,
13791385
}
13801386

13811387
ArrayRef<int64_t> getDataShape() {
1382-
return getData().getType().getShape();
1388+
auto DataTy = getData().getType();
1389+
if (auto vecTy = llvm::dyn_cast<VectorType>(DataTy))
1390+
return vecTy.getShape();
1391+
return {};
13831392
}
13841393

13851394
}];

mlir/lib/Conversion/XeGPUToXeVM/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ add_mlir_conversion_library(MLIRXeGPUToXeVM
2121
MLIRIndexDialect
2222
MLIRSCFDialect
2323
MLIRXeGPUDialect
24+
MLIRXeGPUUtils
2425
MLIRPass
2526
MLIRTransforms
2627
MLIRSCFTransforms

mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp

Lines changed: 70 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
#include "mlir/Dialect/SCF/IR/SCF.h"
2222
#include "mlir/Dialect/SCF/Transforms/Patterns.h"
2323
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
24+
#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
2425
#include "mlir/Pass/Pass.h"
2526
#include "mlir/Support/LLVM.h"
2627
#include "llvm/Support/FormatVariadic.h"
@@ -371,8 +372,6 @@ static Value addOffset(ConversionPatternRewriter &rewriter, Location loc,
371372
Value baseAddr, Value offset, int64_t elemByteSize) {
372373
Value byteSize = arith::ConstantIntOp::create(
373374
rewriter, loc, rewriter.getI64Type(), elemByteSize);
374-
offset = arith::IndexCastUIOp::create(rewriter, loc, rewriter.getI64Type(),
375-
offset);
376375
Value byteOffset = arith::MulIOp::create(rewriter, loc, offset, byteSize);
377376
Value newAddr = arith::AddIOp::create(rewriter, loc, baseAddr, byteOffset);
378377
return newAddr;
@@ -583,6 +582,8 @@ class LoadStoreMatrixToXeVMPattern : public OpConversionPattern<OpType> {
583582
else
584583
data = adaptor.getData();
585584
VectorType valOrResVecTy = dyn_cast<VectorType>(data.getType());
585+
if (!valOrResVecTy)
586+
valOrResVecTy = VectorType::get(1, data.getType());
586587

587588
int64_t elemBitWidth =
588589
valOrResVecTy.getElementType().getIntOrFloatBitWidth();
@@ -606,22 +607,81 @@ class LoadStoreMatrixToXeVMPattern : public OpConversionPattern<OpType> {
606607
rewriter, loc, rewriter.getI64Type(), basePtrLLVM);
607608

608609
Value linearOffset = mdescTy.getLinearOffsets(rewriter, loc, offsets);
610+
linearOffset = arith::IndexCastUIOp::create(
611+
rewriter, loc, rewriter.getI64Type(), linearOffset);
609612
basePtrI64 =
610613
addOffset(rewriter, loc, basePtrI64, linearOffset, elemByteSize);
611614

612615
// convert base pointer (i64) to LLVM pointer type
613616
basePtrLLVM =
614617
LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtrI64);
615618

616-
if constexpr (std::is_same_v<OpType, xegpu::LoadMatrixOp>) {
617-
618-
Value loadOp =
619-
LLVM::LoadOp::create(rewriter, loc, valOrResVecTy, basePtrLLVM);
620-
rewriter.replaceOp(op, loadOp);
619+
// A single-element valOrResVecTy lowers to a scalar load/store
620+
// operation: LLVM load/store does not support single-element vectors,
621+
// so this case is handled separately.
622+
if (valOrResVecTy.getNumElements() == 1) {
623+
Type scalarTy = valOrResVecTy.getElementType();
624+
if constexpr (std::is_same_v<OpType, xegpu::LoadMatrixOp>) {
625+
Value loadOp =
626+
LLVM::LoadOp::create(rewriter, loc, scalarTy, basePtrLLVM);
627+
rewriter.replaceOp(op, loadOp);
628+
} else {
629+
auto storeOp = LLVM::StoreOp::create(rewriter, loc, adaptor.getData(),
630+
basePtrLLVM);
631+
rewriter.eraseOp(op);
632+
}
633+
return success();
621634
} else {
622-
auto storeOp =
623-
LLVM::StoreOp::create(rewriter, loc, adaptor.getData(), basePtrLLVM);
624-
rewriter.eraseOp(op);
635+
// If the 'subgroup_block_io' attribute is set, lower to the XeVM
636+
// subgroup block ops (xevm.blockload / xevm.blockstore).
637+
auto subgroupBlockIoAttr = op.getSubgroupBlockIoAttr();
638+
bool subgroup_block_io =
639+
subgroupBlockIoAttr && cast<BoolAttr>(subgroupBlockIoAttr).getValue();
640+
if (subgroup_block_io) {
641+
if constexpr (std::is_same_v<OpType, xegpu::LoadMatrixOp>) {
642+
Value loadOp = xevm::BlockLoadOp::create(rewriter, loc, valOrResVecTy,
643+
basePtrLLVM);
644+
rewriter.replaceOp(op, loadOp);
645+
} else {
646+
xevm::BlockStoreOp::create(rewriter, loc, basePtrLLVM,
647+
adaptor.getData(), nullptr);
648+
rewriter.eraseOp(op);
649+
}
650+
} else {
651+
// For a 1D result, if vec_direction is COL the memory descriptor
652+
// must be column major (and row major for ROW); these constraints
653+
// are verified below.
654+
auto chipOpt = xegpu::getChipStr(op);
655+
if (!chipOpt || (*chipOpt != "pvc" && *chipOpt != "bmg")) {
656+
// the lowering only works for pvc and bmg
657+
return rewriter.notifyMatchFailure(
658+
op, "The lowering is specific to pvc or bmg.");
659+
}
660+
xegpu::MatrixAccessDirectionAttr vecDirection =
661+
op.getVecDirectionAttr();
662+
if (vecDirection &&
663+
vecDirection.getValue() == xegpu::MatrixAccessDirection::COL &&
664+
!mdescTy.isColMajor())
665+
return rewriter.notifyMatchFailure(
666+
op, "mem_desc should be column major when "
667+
"vec_direction is COLUMN for 1D result.");
668+
if (vecDirection &&
669+
vecDirection.getValue() == xegpu::MatrixAccessDirection::ROW &&
670+
mdescTy.isColMajor())
671+
return rewriter.notifyMatchFailure(
672+
op, "mem_desc should be row major when "
673+
"vec_direction is ROW for 1D result.");
674+
675+
if constexpr (std::is_same_v<OpType, xegpu::LoadMatrixOp>) {
676+
Value loadOp =
677+
LLVM::LoadOp::create(rewriter, loc, valOrResVecTy, basePtrLLVM);
678+
rewriter.replaceOp(op, loadOp);
679+
} else {
680+
auto storeOp = LLVM::StoreOp::create(rewriter, loc, adaptor.getData(),
681+
basePtrLLVM);
682+
rewriter.eraseOp(op);
683+
}
684+
}
625685
}
626686
return success();
627687
}

mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -813,9 +813,8 @@ SmallVector<int64_t> MemDescType::getStrides() {
813813
}
814814
llvm::dbgs() << "]\n";
815815

816-
if (innerBlkShape.empty())
817-
return strides;
818-
816+
// Build perm ordering dims from fastest- to slowest-changing:
817+
// perm[i] is the dim with the i-th smallest stride.
819818
SmallVector<int, 4> perm =
820819
llvm::to_vector<4>(llvm::seq<int>(0, strides.size()));
821820
llvm::sort(perm, [&](int a, int b) { return strides[a] < strides[b]; });
@@ -908,6 +907,7 @@ SmallVector<int64_t> MemDescType::getStrides() {
908907
Value MemDescType::getLinearOffsets(OpBuilder &builder, Location loc,
909908
ArrayRef<OpFoldResult> offsets) {
910909

910+
SmallVector<int64_t> matrixShape(getShape().begin(), getShape().end());
911911
SmallVector<int64_t> blockShape = getBlockSize();
912912
SmallVector<int64_t> strides = getStrides();
913913

@@ -917,7 +917,11 @@ Value MemDescType::getLinearOffsets(OpBuilder &builder, Location loc,
917917
llvm::interleaveComma(strides, llvm::dbgs());
918918
llvm::dbgs() << "]\n");
919919

920-
if (!blockShape.empty()) {
920+
// A block shape equal to the matrix shape means no blocking is applied.
921+
if (llvm::equal(blockShape, matrixShape)) {
922+
// remove the outer dims from strides
923+
strides.erase(strides.begin(), strides.begin() + matrixShape.size());
924+
} else {
921925
assert(offsets.size() == blockShape.size() &&
922926
"offsets and blockShape must have the same size");
923927
// say the original offset is [y, x], and the block shape is [By, Bx],

mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp

Lines changed: 64 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -173,6 +173,51 @@ isValidGatherScatterBufferParams(Type offsetsTy, Type maskTy,
173173
return success();
174174
}
175175

176+
LogicalResult IsValidStoreMatrixParams(
177+
VectorType dataTy, MemDescType mdescTy, UnitAttr subgroup_block_io,
178+
MatrixAccessDirectionAttr vecDirection, IntegerAttr vecLength,
179+
function_ref<InFlightDiagnostic()> emitError) {
180+
181+
if (!dataTy)
182+
if (subgroup_block_io || vecDirection || vecLength)
183+
return emitError() << "vec_length, vec_direction and subgroup_block_io "
184+
"are only allowed when result is a 1D VectorType.";
185+
else
186+
return success();
187+
188+
if (mdescTy.getRank() != 2)
189+
return emitError() << "mem_desc must be 2D.";
190+
191+
ArrayRef<int64_t> dataShape = dataTy.getShape();
192+
ArrayRef<int64_t> mdescShape = mdescTy.getShape();
193+
194+
if (dataShape.size() == 2) {
195+
if (subgroup_block_io || vecDirection || vecLength)
196+
return emitError() << "vec_length, vec_direction and subgroup_block_io "
197+
"are only allowed when result is a 1D VectorType.";
198+
if (llvm::any_of(llvm::zip_equal(dataShape, mdescShape),
199+
[](auto p) { return std::get<0>(p) > std::get<1>(p); }))
200+
return emitError() << "data shape must not exceed mem_desc shape.";
201+
} else if (dataShape.size() == 1) {
202+
203+
SmallVector<int64_t> blockSize = mdescTy.getBlockSize();
204+
// When subgroup_block_io is set, the mem_desc must carry a block
205+
// attribute (i.e. a non-empty block size).
206+
if (subgroup_block_io && !blockSize.size())
207+
return emitError() << "mem_desc must have block attribute when "
208+
"subgroup_block_io is set.";
209+
// When subgroup_block_io is set, the mem_desc must be laid out
210+
// row major.
211+
if (subgroup_block_io && mdescTy.isColMajor())
212+
return emitError() << "mem_desc should be row major when "
213+
"subgroup_block_io is set.";
214+
} else if (dataShape.size() == 0) {
215+
return emitError() << "result shape must not be empty.";
216+
}
217+
218+
return success();
219+
}
220+
176221
//===----------------------------------------------------------------------===//
177222
// XeGPU_CreateNdDescOp
178223
//===----------------------------------------------------------------------===//
@@ -1053,25 +1098,20 @@ void LoadMatrixOp::build(OpBuilder &builder, OperationState &state, Type res,
10531098
// nullptr/empty)
10541099
build(builder, state, res, memDesc, dynamicOffsets, staticOffsetsAttr,
10551100
/*vec_length=*/nullptr, /*vec_direction=*/nullptr,
1056-
/*subgroupBlockIO=*/nullptr, layout);
1101+
/*subgroup_block_io=*/nullptr, layout);
10571102
}
10581103

10591104
LogicalResult LoadMatrixOp::verify() {
1060-
VectorType resTy = getRes().getType();
1061-
MemDescType mdescTy = getMemDesc().getType();
1062-
1063-
if (mdescTy.getRank() != 2)
1064-
return emitOpError("mem_desc must be 2D.");
10651105

1066-
ArrayRef<int64_t> valueShape = resTy.getShape();
1067-
ArrayRef<int64_t> mdescShape = mdescTy.getShape();
1106+
auto resTy = dyn_cast<VectorType>(getRes().getType());
1107+
UnitAttr subgroup_block_io = getSubgroupBlockIoAttr();
1108+
MatrixAccessDirectionAttr vecDirection = getVecDirectionAttr();
1109+
IntegerAttr vecLength = getVecLengthAttr();
1110+
MemDescType mdescTy = getMemDesc().getType();
10681111

1069-
if (valueShape.size() != 1) {
1070-
if (llvm::any_of(llvm::zip_equal(valueShape, mdescShape),
1071-
[](auto p) { return std::get<0>(p) > std::get<1>(p); }))
1072-
return emitOpError("result shape must not exceed mem_desc shape.");
1073-
}
1074-
return success();
1112+
return IsValidStoreMatrixParams(resTy, mdescTy, subgroup_block_io,
1113+
vecDirection, vecLength,
1114+
[&]() { return emitError(); });
10751115
}
10761116

10771117
//===----------------------------------------------------------------------===//
@@ -1086,24 +1126,20 @@ void StoreMatrixOp::build(OpBuilder &builder, OperationState &state, Value data,
10861126
dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
10871127
auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets);
10881128
build(builder, state, data, memDesc, dynamicOffsets, staticOffsetsAttr,
1089-
layout);
1129+
/*vec_length=*/nullptr, /*vec_direction=*/nullptr,
1130+
/*subgroup_block_io=*/nullptr, layout);
10901131
}
10911132

10921133
LogicalResult StoreMatrixOp::verify() {
1093-
VectorType dataTy = getData().getType();
1094-
MemDescType mdescTy = getMemDesc().getType();
10951134

1096-
if (mdescTy.getRank() != 2)
1097-
return emitOpError("mem_desc must be 2D.");
1098-
1099-
ArrayRef<int64_t> dataShape = dataTy.getShape();
1100-
ArrayRef<int64_t> mdescShape = mdescTy.getShape();
1101-
if (dataShape.size() != 1) {
1102-
if (llvm::any_of(llvm::zip_equal(dataShape, mdescShape),
1103-
[](auto p) { return std::get<0>(p) > std::get<1>(p); }))
1104-
return emitOpError("data shape must not exceed mem_desc shape.");
1105-
}
1106-
return success();
1135+
auto dataTy = dyn_cast<VectorType>(getData().getType());
1136+
UnitAttr subgroup_block_io = getSubgroupBlockIoAttr();
1137+
MatrixAccessDirectionAttr vecDirection = getVecDirectionAttr();
1138+
IntegerAttr vecLength = getVecLengthAttr();
1139+
MemDescType mdescTy = getMemDesc().getType();
1140+
return IsValidStoreMatrixParams(dataTy, mdescTy, subgroup_block_io,
1141+
vecDirection, vecLength,
1142+
[&]() { return emitError(); });
11071143
}
11081144

11091145
//===----------------------------------------------------------------------===//

mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -941,7 +941,7 @@ struct UnrollLoadMatrixOp : public UnrollPattern<xegpu::LoadMatrixOp> {
941941
LogicalResult matchAndRewrite(xegpu::LoadMatrixOp op,
942942
PatternRewriter &rewriter) const override {
943943
Location loc = op.getLoc();
944-
VectorType valueTy = op.getType();
944+
VectorType valueTy = llvm::dyn_cast<VectorType>(op.getType());
945945
std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
946946
if (!targetShape || targetShape->size() != (size_t)valueTy.getRank())
947947
return failure();
@@ -984,7 +984,7 @@ struct UnrollStoreMatrixOp : public UnrollPattern<xegpu::StoreMatrixOp> {
984984
return failure();
985985

986986
Location loc = op.getLoc();
987-
VectorType valueTy = op.getData().getType();
987+
VectorType valueTy = llvm::dyn_cast<VectorType>(op.getData().getType());
988988
ArrayRef<int64_t> shape = valueTy.getShape();
989989
auto layout = dyn_cast<xegpu::LayoutAttr>(op.getLayoutAttr());
990990

mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -867,7 +867,7 @@ struct WgToSgLoadMatrixOp : public OpConversionPattern<xegpu::LoadMatrixOp> {
867867
return failure();
868868

869869
ArrayRef<int64_t> wgShape = op.getDataShape();
870-
VectorType valueTy = op.getRes().getType();
870+
VectorType valueTy = llvm::dyn_cast<VectorType>(op.getRes().getType());
871871
Type elemTy = valueTy.getElementType();
872872

873873
xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();

mlir/test/Conversion/XeGPUToXeVM/dpas.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ gpu.module @test_kernel {
77
// Loads are checked in a separate test.
88
// CHECK: %[[D:.*]] = xevm.mma %[[ARG0]], %[[ARG1]], %[[ARG2]] {shape = <m = 8, n = 16, k = 16>, types = <d = f32, a = f16, b = f16, c = f32>}
99
// CHECK-SAME: : (vector<8xf16>, vector<16xf16>, vector<8xf32>) -> vector<8xf32>
10-
%d = xegpu.dpas %a_loaded, %b_loaded, %c_loaded {a_layout = #sg_map_a_f16, b_layout = #sg_map_b_f16, c_layout = #sg_map_c_f32}
10+
%d = xegpu.dpas %a_loaded, %b_loaded, %c_loaded
1111
: vector<8xf16>, vector<16xf16>, vector<8xf32> -> vector<8xf32>
1212
return %d : vector<8xf32>
1313
}

0 commit comments

Comments
 (0)