Commit c858863

Add transpose_bit_width attribute to xegpu.load_nd (#693)

1 parent 6deb412

8 files changed: +85 -16 lines changed

include/imex/Dialect/XeGPU/IR/XeGPUOps.td

Lines changed: 3 additions & 1 deletion

@@ -218,14 +218,16 @@ def XeGPU_LoadNDOp : XeGPU_Op<"load_nd"> {
       XeGPU_TensorDesc: $TensorDesc,
       OptionalAttr<I32Attr>: $vnni_axis,
       OptionalAttr<DenseI64ArrayAttr>: $transpose,
+      OptionalAttr<I32Attr>: $transpose_bit_width,
       OptionalAttr<XeGPU_CacheReadAttr>: $l1_hint,
       OptionalAttr<XeGPU_CacheReadAttr>: $l2_hint,
       OptionalAttr<XeGPU_CacheReadAttr>: $l3_hint,
       DefaultValuedAttr<XeGPU_ModeAttr, "imex::xegpu::Mode::SIMT">: $mode);
+
   let results = (outs XeGPU_ValueType: $value);

   let extraClassDeclaration = [{
-    mlir::VectorType getValueType() {
+    mlir::VectorType getType() {
       return llvm::dyn_cast_if_present<mlir::VectorType>(getValue().getType());
     }
lib/Conversion/XeGPUToSPIRV/XeGPUToSPIRV.cpp

Lines changed: 7 additions & 6 deletions

@@ -661,12 +661,13 @@ class LoadStorePrefetchNdToRawSend : public OpConversionPattern<OpType> {
   auto extMsg = createIntConstant(i32Type, 0);
   auto dataSize2D = (encodeDataum(elmType) - 1);
   auto payLoad = adaptor.getTensorDesc();
-  // The vnni and transpose combination is required for the case where the B
-  // matrix is transposed and we need to load from B in DPAS layout. However,
-  // the HW does not support vnni and transpose together. We can get the same
-  // layout for the B load by doing the transpose at 32-bit granularity.
-  // TODO: Transpose granularity must be explicitly represented in the XeGPU op.
-  if (vnni && transpose) {
+
+  // TODO: currently limit transposeBitWidth to 32; it is an architecture
+  // feature, and 32 works on PVC but may not on FS. To support other bit
+  // widths, we cannot hardcode i32Type and need to generalize the logic.
+  auto loadOp = llvm::dyn_cast<LoadNDOp>(op.getOperation());
+  if (loadOp && transpose && loadOp.getTransposeBitWidth() == 32) {
     // in raw_send msg set vnni effect to false and update data size of
     // payload item to 32 bits
     vnni = false;
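The removed comment captures the motivation: loading a transposed B operand for DPAS needs both a VNNI transform and a transpose, but the hardware 2D block load cannot apply both at once; transposing at 32-bit granularity yields the same packed layout. A standalone C++ sketch of that equivalence (illustration only, not part of the patch; the matrix size and all names are made up):

// A 32-bit-granularity transpose of a row-major K x N f16 matrix B produces,
// byte for byte, the vnni_axis = 0 packing of B^T that DPAS expects for B.
#include <cassert>
#include <cstdint>
#include <cstring>
#include <vector>

int main() {
  const int K = 16, N = 16; // f16 elements modeled as uint16_t
  std::vector<uint16_t> B(K * N);
  for (int i = 0; i < K * N; ++i)
    B[i] = static_cast<uint16_t>(i);

  // vnni_axis = 0 packing of B^T: shape (N/2) x K x 2, where
  // vnni[n][k][i] = B^T[2n + i][k] = B[k][2n + i].
  std::vector<uint16_t> vnni(N * K);
  for (int n = 0; n < N / 2; ++n)
    for (int k = 0; k < K; ++k)
      for (int i = 0; i < 2; ++i)
        vnni[(n * K + k) * 2 + i] = B[k * N + 2 * n + i];

  // Transpose at 32-bit granularity: view each pair of f16 values as one
  // uint32_t, giving a K x (N/2) matrix, and move whole 32-bit items.
  std::vector<uint32_t> b32(K * N / 2), t32(K * N / 2);
  std::memcpy(b32.data(), B.data(), B.size() * sizeof(uint16_t));
  for (int k = 0; k < K; ++k)
    for (int n = 0; n < N / 2; ++n)
      t32[n * K + k] = b32[k * (N / 2) + n];

  // Identical layouts, so no separate vnni step is needed after the load.
  assert(std::memcmp(t32.data(), vnni.data(), N * K * 2) == 0);
  return 0;
}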

lib/Conversion/XeTileToXeGPU/XeTileOpConversion.cpp

Lines changed: 3 additions & 2 deletions

@@ -440,6 +440,7 @@ struct SgLoadTileOpPattern
   auto L3 = xegpu::CacheReadHintAttr::get(ctx, xegpu::CacheReadHint::CACHED);

   mlir::IntegerAttr vnniAttr;
+  mlir::IntegerAttr transposeBitWidthAttr;
   // TODO: move these into an architecture abstraction in the future.
   const int SIMD_WIDTH_IN_BITS = 32;
   int factor = SIMD_WIDTH_IN_BITS / elemTy.getIntOrFloatBitWidth();

@@ -471,8 +472,8 @@

   auto vectorTy = mlir::VectorType::get(shape, tileTy.getElementType());
   auto ldOp = rewriter.create<xegpu::LoadNDOp>(
-      op.getLoc(), vectorTy, src, vnniAttr, transposeAttr, L1, L2, L3,
-      imex::xegpu::Mode::VC);
+      op.getLoc(), vectorTy, src, vnniAttr, transposeAttr,
+      transposeBitWidthAttr, L1, L2, L3, imex::xegpu::Mode::VC);
   if (array_length == 1) {
     xegpuOps.push_back(ldOp);
   } else {

lib/Dialect/XeGPU/IR/XeGPUOps.cpp

Lines changed: 35 additions & 5 deletions

@@ -140,7 +140,8 @@ parseOptionalAttrDict(mlir::OpAsmParser &parser, mlir::OperationState &result,
     return parseCustomEnumAttr<Mode, ModeAttr>(parser, result, nameId);
   }

-  if (nameId == "chunk_size_per_lane" || nameId == "vnni_axis")
+  if (nameId == "chunk_size_per_lane" || nameId == "vnni_axis" ||
+      nameId == "transpose_bit_width")
     return parseBoolAndIntegerAttr<mlir::IntegerAttr>(parser, result, nameId);

   if (nameId == "boundary_check")

@@ -727,9 +728,10 @@ mlir::ParseResult LoadNDOp::parse(mlir::OpAsmParser &parser,
   if (parser.parseOperand(TensorDescRawOperands[0]))
     return mlir::failure();

-  if (parseOptionalAttrDict(
-          parser, result,
-          {"mode", "vnni_axis", "transpose", "l1_hint", "l2_hint", "l3_hint"}))
+  if (parseOptionalAttrDict(parser, result,
+                            {"mode", "vnni_axis", "transpose",
+                             "transpose_bit_width", "l1_hint", "l2_hint",
+                             "l3_hint"}))
     return mlir::failure();

   if (parser.parseColon())

@@ -789,6 +791,13 @@ void LoadNDOp::print(mlir::OpAsmPrinter &printer) {
     printSep = true;
   }

+  if (getTransposeBitWidthAttr()) {
+    if (printSep)
+      printer << "," << ' ';
+    printer << "transpose_bit_width = " << getTransposeBitWidth().value();
+    printSep = true;
+  }
+
   printCacheHintAttrs<LoadNDOp>(printer, *this, printSep);

   if (printDefaults || mode != imex::xegpu::Mode::SIMT || numAttrs > 1) {

@@ -805,7 +814,7 @@

 mlir::LogicalResult LoadNDOp::verify() {
   auto tdescTy = getTensorDescType();
-  auto valueTy = getValueType();
+  auto valueTy = getType();

   if (tdescTy.getRank() > 2)
     return emitOpError(

@@ -845,6 +854,17 @@ mlir::LogicalResult LoadNDOp::verify() {
     }
   }

+  // TODO: remove the following two checks once there is a verifier
+  // against an architecture for handwritten code.
+  if (getTranspose() == llvm::ArrayRef<int64_t>({1, 0}) && getVnniAxis() == 0) {
+    return emitOpError("Transpose and VNNI are mutually exclusive.");
+  }
+
+  if (getVnniAxis() == 0 && getTransposeBitWidth()) {
+    return emitOpError("TransposeBitWidth and VNNI are mutually exclusive. "
+                       "TransposeBitWidth implies a VNNI transform on axis 0.");
+  }
+
   if (getTranspose()) {
     auto trans = getTranspose().value();
     if (tdescShape.size() >= trans.size())

@@ -860,6 +880,16 @@
     tdescShape.push_back(vnni_factor);
   }

+  if (getTransposeBitWidth()) {
+    auto bitWidth = getTransposeBitWidth().value();
+    if (bitWidth != 32)
+      return emitOpError("Invalid bit width for transpose.");
+    auto vnni_factor = valueShape.back();
+    // transpose_bit_width implies a vnni transform on axis 0
+    tdescShape[0] /= vnni_factor;
+    tdescShape.push_back(vnni_factor);
+  }
+
   if (array_len > 1) {
     auto it = tdescShape.begin();
     tdescShape.insert(it, array_len);
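To make the new verifier rule concrete: tensor_desc<8x32xf16> with transpose = [1, 0] and transpose_bit_width = 32 must produce vector<16x8x2xf16>. Transposing gives 32x8; packing f16 pairs into 32-bit units halves axis 0 to 16 and appends an innermost dimension of 2. A minimal C++ sketch of this shape computation (expectedLoadShape is a hypothetical helper written for illustration; the real logic lives inline in LoadNDOp::verify() above):

#include <cassert>
#include <cstdint>
#include <utility>
#include <vector>

std::vector<int64_t> expectedLoadShape(std::vector<int64_t> tdescShape,
                                       int64_t vnniFactor) {
  std::swap(tdescShape[0], tdescShape[1]); // transpose = [1, 0]
  tdescShape[0] /= vnniFactor;             // transpose_bit_width = 32 packs
  tdescShape.push_back(vnniFactor);        // f16 pairs along axis 0
  return tdescShape;
}

int main() {
  // tensor_desc<8x32xf16>: two f16 values fit in 32 bits, so vnniFactor = 2.
  auto shape = expectedLoadShape({8, 32}, 2);
  assert((shape == std::vector<int64_t>{16, 8, 2})); // vector<16x8x2xf16>
  return 0;
}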

lib/Utils/XeArch.cpp

Lines changed: 8 additions & 1 deletion

@@ -307,7 +307,14 @@ mlir::LogicalResult XeuArchInterface::isLegalLoad2dOp(mlir::Operation *op) {

   LoadStore2DConfig loadParams;
   bool vnni = loadOp.getVnniAxis() == 0 ? true : false;
-  bool transpose = loadOp.getTranspose() ? true : false;
+  bool transpose =
+      loadOp.getTranspose() == llvm::ArrayRef<int64_t>({1, 0}) ? true : false;
+
+  if (vnni && transpose) {
+    return loadOp->emitOpError(
+        "Transpose and VNNI are mutually exclusive. They are "
+        "not supported by the PVC hardware at the same time.\n");
+  }

   mlir::FailureOr<LoadStore2DConfig> configParams =
       this->get2DLoadConfig(op, elementSize, vnni, transpose);

test/Dialect/XeGPU/IR/invalid_vc.mlir

Lines changed: 12 additions & 0 deletions

@@ -68,3 +68,15 @@ func.func @test_load_gather(%src: ui64, %offsets : vector<16xindex>) {
     : !xegpu.tensor_desc<16x8xf16, #xegpu.scattered>, vector<16x8xi1> -> vector<8x8x4xf16>
   return
 }
+
+// -----
+func.func @test_load_nd(%input: memref<24x32xf16>) {
+  %c0 = arith.constant 0 : index
+  %1 = xegpu.create_nd_tdesc %input[%c0, %c0] {mode = vc}
+    : memref<24x32xf16> -> !xegpu.tensor_desc<16x16xf16>
+  // Hardware doesn't support VNNI transform and transpose at the same time.
+  // expected-error@+1 {{Transpose and VNNI are mutually exclusive.}}
+  %2 = xegpu.load_nd %1 {mode = vc, vnni_axis = 0, transpose = [1, 0]}
+    : !xegpu.tensor_desc<16x16xf16> -> vector<8x16x2xf16>
+  return
+}

test/Dialect/XeGPU/IR/load_nd_vc.mlir

Lines changed: 16 additions & 0 deletions

@@ -75,3 +75,19 @@ func.func @test_load_nd_block_array_simd_f16(%src: memref<8x32xf16>) {
     : !xegpu.tensor_desc<8x16xf16, #xegpu.tdesc_attr<array_length = 2>> -> vector<2x8x16xf16>
   return
 }
+
+
+// CHECK-LABEL: func @test_load_nd_transpose_bit_width_simd_f16({{.*}}) {
+func.func @test_load_nd_transpose_bit_width_simd_f16(%src: memref<8x32xf16>) {
+  // CHECK: xegpu.create_nd_tdesc
+  // CHECK-SAME: {mode = vc}
+  // CHECK-SAME: memref<8x32xf16> -> !xegpu.tensor_desc<8x32xf16>
+  %1 = xegpu.create_nd_tdesc %src[0, 0] {mode = vc} : memref<8x32xf16> -> !xegpu.tensor_desc<8x32xf16>
+
+  // CHECK: xegpu.load_nd
+  // CHECK-SAME: {mode = vc, transpose = [1, 0], transpose_bit_width = 32, l1_hint = cached, l2_hint = uncached}
+  // CHECK-SAME: !xegpu.tensor_desc<8x32xf16> -> vector<16x8x2xf16>
+  %2 = xegpu.load_nd %1 {mode = vc, transpose = [1, 0], transpose_bit_width = 32, l1_hint = cached, l2_hint = uncached}
+    : !xegpu.tensor_desc<8x32xf16> -> vector<16x8x2xf16>
+  return
+}

test/Integration/Dialect/XeGPU/gemm_with_transposed_B_1kx1kx1k_f16_f16_f32.mlir

Lines changed: 1 addition & 1 deletion

@@ -41,7 +41,7 @@ module @gemm attributes {gpu.container_module} {
   %7 = xegpu.create_nd_tdesc %arg0[%2, %arg3] {mode = vc} : memref<1024x1024xf16> -> !xegpu.tensor_desc<8x16xf16>
   %8 = xegpu.create_nd_tdesc %arg1[%3, %arg3] {mode = vc} : memref<1024x1024xf16> -> !xegpu.tensor_desc<16x16xf16>
   %9 = xegpu.load_nd %7 {mode = vc, vnni_axis = 1}: !xegpu.tensor_desc<8x16xf16> -> vector<8x8x2xf16>
-  %10 = xegpu.load_nd %8 {mode = vc, vnni_axis = 0, transpose = [1, 0]} : !xegpu.tensor_desc<16x16xf16> -> vector<8x16x2xf16>
+  %10 = xegpu.load_nd %8 {mode = vc, transpose_bit_width = 32, transpose = [1, 0]} : !xegpu.tensor_desc<16x16xf16> -> vector<8x16x2xf16>
   %11 = xegpu.dpas %9, %10, %arg4 {mode = vc} : vector<8x8x2xf16>, vector<8x16x2xf16>, vector<8x16xf32> -> vector<8x16xf32>
   scf.yield %11 : vector<8x16xf32>
 }
