45 changes: 45 additions & 0 deletions mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
@@ -261,6 +261,21 @@ def XeGPU_PrefetchNdOp : XeGPU_Op<"prefetch_nd", []> {
: !xegpu.tensor_desc<8x16xf16>
```

The operation may take optional offsets for the tensor descriptor.
The number of offsets must be greater than or equal to the rank of the tensor
descriptor and less than or equal to the rank of the source memref. The offsets
are applied to the innermost dimensions of the source memref.

Examples:
```mlir
%tdesc = xegpu.create_nd_tdesc %0: memref<2x8x32x32xf16> -> TensorDesc<8x16xf16>
// memref[0, 0, %off0, %off1]
xegpu.prefetch_nd %tdesc[%off0, %off1] : !xegpu.tensor_desc<8x16xf16>
// memref[0, %off0, %off1, %off2]
xegpu.prefetch_nd %tdesc[%off0, %off1, %off2] : !xegpu.tensor_desc<8x16xf16>
// memref[%off0, %off1, %off2, %off3]
xegpu.prefetch_nd %tdesc[%off0, %off1, %off2, %off3] : !xegpu.tensor_desc<8x16xf16>
```
}];

let arguments = (ins XeGPU_TensorDesc: $TensorDesc,
@@ -350,6 +365,21 @@ def XeGPU_LoadNdOp : XeGPU_Op<"load_nd", [
: !xegpu.tensor_desc<8x16xf32> -> vector<8xf32>
```

The operation may take optional offsets for the tensor descriptor.
The number of offsets must be greater than or equal to the rank of the tensor
descriptor and less than or equal to the rank of the source memref. The offsets
are applied to the innermost dimensions of the source memref.

Examples:
```mlir
%1 = xegpu.create_nd_tdesc %0: memref<2x8x32x32xf32> -> TensorDesc<8x16xf32>
// memref[0, 0, %off0, %off1]
xegpu.load_nd %1[%off0, %off1] : !xegpu.tensor_desc<8x16xf32> -> vector<8xf32>
// memref[0, %off0, %off1, %off2]
xegpu.load_nd %1[%off0, %off1, %off2] : !xegpu.tensor_desc<8x16xf32> -> vector<8xf32>
// memref[%off0, %off1, %off2, %off3]
xegpu.load_nd %1[%off0, %off1, %off2, %off3] : !xegpu.tensor_desc<8x16xf32> -> vector<8xf32>
```

}];

@@ -445,6 +475,21 @@ def XeGPU_StoreNdOp : XeGPU_Op<"store_nd", [
: vector<8xf16>, !xegpu.tensor_desc<8x16xf16>
```

The operation may take optional offsets for the tensor descriptor.
The number of offsets must be greater than or equal to the rank of the tensor
descriptor and less than or equal to the rank of the source memref. The offsets
are applied to the innermost dimensions of the source memref.

Examples:
```mlir
%2 = xegpu.create_nd_tdesc %0: memref<2x8x32x32xf16> -> TensorDesc<8x16xf16>
// memref[0, 0, %off0, %off1]
xegpu.store_nd %3, %2[%off0, %off1] : vector<8x16xf16>, !xegpu.tensor_desc<8x16xf16>
// memref[0, %off0, %off1, %off2]
xegpu.store_nd %3, %2[%off0, %off1, %off2] : vector<8x16xf16>, !xegpu.tensor_desc<8x16xf16>
// memref[%off0, %off1, %off2, %off3]
xegpu.store_nd %3, %2[%off0, %off1, %off2, %off3] : vector<8x16xf16>, !xegpu.tensor_desc<8x16xf16>
```

}];

99 changes: 41 additions & 58 deletions mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
@@ -97,17 +97,16 @@ static LogicalResult transferPreconditions(PatternRewriter &rewriter,
return success();
}

static xegpu::CreateNdDescOp
createNdDescriptor(PatternRewriter &rewriter, Location loc,
xegpu::TensorDescType descType, TypedValue<MemRefType> src,
Operation::operand_range offsets) {
static xegpu::CreateNdDescOp createNdDescriptor(PatternRewriter &rewriter,
Location loc,
xegpu::TensorDescType descType,
TypedValue<MemRefType> src) {
MemRefType srcTy = src.getType();
auto [strides, offset] = srcTy.getStridesAndOffset();

xegpu::CreateNdDescOp ndDesc;
if (srcTy.hasStaticShape()) {
ndDesc = xegpu::CreateNdDescOp::create(rewriter, loc, descType, src,
getAsOpFoldResult(offsets));
ndDesc = xegpu::CreateNdDescOp::create(rewriter, loc, descType, src);
} else {
// In case of any dynamic shapes, source's shape and strides have to be
// explicitly provided.
@@ -116,38 +115,19 @@ createNdDescriptor(PatternRewriter &rewriter, Location loc,
for (unsigned i = 0; i < srcRank; ++i)
sourceDims.push_back(memref::DimOp::create(rewriter, loc, src, i));

SmallVector<int64_t> constOffsets;
SmallVector<Value> dynOffsets;
for (Value offset : offsets) {
std::optional<int64_t> staticVal = getConstantIntValue(offset);
if (!staticVal)
dynOffsets.push_back(offset);
constOffsets.push_back(staticVal.value_or(ShapedType::kDynamic));
}

SmallVector<Value> dynShapes;
SmallVector<OpFoldResult> mixedShapes;
for (auto [idx, shape] : llvm::enumerate(srcTy.getShape())) {
if (shape == ShapedType::kDynamic)
dynShapes.push_back(sourceDims[idx]);
}

// Compute strides in reverse order.
SmallVector<Value> dynStrides;
Value accStride = arith::ConstantIndexOp::create(rewriter, loc, 1);
// Last stride is guaranteed to be static and unit.
for (int i = static_cast<int>(strides.size()) - 2; i >= 0; --i) {
accStride =
arith::MulIOp::create(rewriter, loc, accStride, sourceDims[i + 1]);
if (strides[i] == ShapedType::kDynamic)
dynStrides.push_back(accStride);
mixedShapes.push_back(sourceDims[idx]);
else
mixedShapes.push_back(rewriter.getI64IntegerAttr(shape));
}
std::reverse(dynStrides.begin(), dynStrides.end());

ndDesc = xegpu::CreateNdDescOp::create(
rewriter, loc, descType, src, dynOffsets, dynShapes, dynStrides,
DenseI64ArrayAttr::get(rewriter.getContext(), constOffsets),
DenseI64ArrayAttr::get(rewriter.getContext(), srcTy.getShape()),
DenseI64ArrayAttr::get(rewriter.getContext(), strides));
auto meta = memref::ExtractStridedMetadataOp::create(rewriter, loc, src);
SmallVector<OpFoldResult> mixedStrides(meta.getStrides().begin(),
meta.getStrides().end());
ndDesc = xegpu::CreateNdDescOp::create(rewriter, loc, descType, src,
mixedShapes, mixedStrides);
}

return ndDesc;
@@ -523,21 +503,21 @@ struct TransferReadLowering : public OpRewritePattern<vector::TransferReadOp> {
descShape, elementType, /*array_length=*/1,
/*boundary_check=*/isOutOfBounds, xegpu::MemorySpace::Global);

xegpu::CreateNdDescOp ndDesc =
createNdDescriptor(rewriter, loc, descType,
dyn_cast<TypedValue<MemRefType>>(readOp.getBase()),
readOp.getIndices());

DenseI64ArrayAttr transposeAttr =
!isTransposeLoad ? nullptr
: DenseI64ArrayAttr::get(rewriter.getContext(),
ArrayRef<int64_t>{1, 0});
// By default, no specific caching policy is assigned.
xegpu::CachePolicyAttr hint = nullptr;
auto loadOp = xegpu::LoadNdOp::create(rewriter, loc, vecTy, ndDesc,
/*packed=*/nullptr, transposeAttr,
/*l1_hint=*/hint,
/*l2_hint=*/hint, /*l3_hint=*/hint);
xegpu::CreateNdDescOp ndDesc =
createNdDescriptor(rewriter, loc, descType,
dyn_cast<TypedValue<MemRefType>>(readOp.getBase()));

auto loadOp = xegpu::LoadNdOp::create(
rewriter, loc, vecTy, ndDesc, getAsOpFoldResult(readOp.getIndices()),
/*packed=*/nullptr, transposeAttr,
/*l1_hint=*/hint,
/*l2_hint=*/hint, /*l3_hint=*/hint);
rewriter.replaceOp(readOp, loadOp);

return success();
@@ -579,15 +559,15 @@ struct TransferWriteLowering
vecTy.getShape(), vecTy.getElementType(),
/*array_length=*/1, /*boundary_check=*/writeOp.hasOutOfBoundsDim(),
xegpu::MemorySpace::Global);
// By default, no specific caching policy is assigned.
xegpu::CachePolicyAttr hint = nullptr;
xegpu::CreateNdDescOp ndDesc =
createNdDescriptor(rewriter, loc, descType,
dyn_cast<TypedValue<MemRefType>>(writeOp.getBase()),
writeOp.getIndices());
dyn_cast<TypedValue<MemRefType>>(writeOp.getBase()));

// By default, no specific caching policy is assigned.
xegpu::CachePolicyAttr hint = nullptr;
auto storeOp =
xegpu::StoreNdOp::create(rewriter, loc, writeOp.getVector(), ndDesc,
getAsOpFoldResult(writeOp.getIndices()),
/*l1_hint=*/hint,
/*l2_hint=*/hint, /*l3_hint=*/hint);
rewriter.replaceOp(writeOp, storeOp);
@@ -674,17 +654,18 @@ struct LoadLowering : public OpRewritePattern<vector::LoadOp> {

// Boundary check is available only for block instructions.
bool boundaryCheck = vecTy.getRank() > 1;
// By default, no specific caching policy is assigned.
xegpu::CachePolicyAttr hint = nullptr;

auto descType = xegpu::TensorDescType::get(
vecTy.getShape(), vecTy.getElementType(), /*array_length=*/1,
boundaryCheck, xegpu::MemorySpace::Global);
xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
rewriter, loc, descType, loadOp.getBase(), loadOp.getIndices());

// By default, no specific caching policy is assigned.
xegpu::CachePolicyAttr hint = nullptr;
xegpu::CreateNdDescOp ndDesc =
createNdDescriptor(rewriter, loc, descType, loadOp.getBase());
auto loadNdOp = xegpu::LoadNdOp::create(
rewriter, loc, vecTy, ndDesc, /*packed=*/nullptr, /*transpose=*/nullptr,
rewriter, loc, vecTy, ndDesc, getAsOpFoldResult(loadOp.getIndices()),
/*packed=*/nullptr, /*transpose=*/nullptr,
/*l1_hint=*/hint,
/*l2_hint=*/hint, /*l3_hint=*/hint);
rewriter.replaceOp(loadOp, loadNdOp);
@@ -711,15 +692,17 @@ struct StoreLowering : public OpRewritePattern<vector::StoreOp> {
auto descType = xegpu::TensorDescType::get(
vecTy.getShape(), vecTy.getElementType(),
/*array_length=*/1, boundaryCheck, xegpu::MemorySpace::Global);
xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
rewriter, loc, descType, storeOp.getBase(), storeOp.getIndices());

// By default, no specific caching policy is assigned.
xegpu::CachePolicyAttr hint = nullptr;
auto storeNdOp =
xegpu::StoreNdOp::create(rewriter, loc, vector, ndDesc,
/*l1_hint=*/hint,
/*l2_hint=*/hint, /*l3_hint=*/hint);
xegpu::CreateNdDescOp ndDesc =
createNdDescriptor(rewriter, loc, descType, storeOp.getBase());

auto storeNdOp = xegpu::StoreNdOp::create(
rewriter, loc, vector, ndDesc, getAsOpFoldResult(storeOp.getIndices()),
/*l1_hint=*/hint,
/*l2_hint=*/hint, /*l3_hint=*/hint);

rewriter.replaceOp(storeOp, storeNdOp);

return success();
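To summarize the net effect of the VectorToXeGPU changes above, here is a rough before/after sketch for lowering a 2-D load. This is hypothetical IR written for illustration (the `%src`, `%i`, `%j` names and the memref shape are assumptions, and cache hints are omitted); the point is that offsets move from `create_nd_tdesc` onto the consuming `load_nd`:

```mlir
// Before this patch: offsets were baked into the descriptor.
%desc = xegpu.create_nd_tdesc %src[%i, %j]
    : memref<64x64xf32> -> !xegpu.tensor_desc<8x16xf32>
%v = xegpu.load_nd %desc : !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32>

// After this patch: the descriptor is offset-free and the offsets are
// supplied at the point of use.
%desc2 = xegpu.create_nd_tdesc %src
    : memref<64x64xf32> -> !xegpu.tensor_desc<8x16xf32>
%v2 = xegpu.load_nd %desc2[%i, %j]
    : !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32>
```

The same relocation of offsets applies to the store and prefetch lowerings, as the updated tests below check.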
67 changes: 33 additions & 34 deletions mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -121,6 +121,22 @@ isValidGatherScatterParams(Type maskTy, VectorType valueTy,
return success();
}

Review thread:

Contributor:
I think we should simply check that offsets can't have a lower rank than the tdesc rank. It should be fine if it is larger than the tdesc rank.
I am not sure it is common practice for op validation to validate itself using information from the producer op. I would rather check this in transformation or lowering passes. Any opinion from @adam-smnk?

Contributor:
Op verifier should be contained to the op itself without accessing any external data.

Contributor:
Sounds reasonable. Just please clearly define how offsets are interpreted. For example, if the memref is 4D, the descriptor is 2D, and there are 3 offsets, I guess they'd apply to the innermost dimensions.

Contributor Author:
Simplified the check to not rely on CreateNdDescOp. Also added docs for the new offset syntax.

// Verify that the number of offsets, if present, is not smaller than the rank
// of the tensor descriptor.
static LogicalResult
isValidNdOffset(TypedValue<TensorDescType> tDesc,
std::optional<llvm::ArrayRef<int64_t>> constOffsets,
int64_t offsetSize,
function_ref<InFlightDiagnostic()> emitError) {
int64_t constOffsetSize = constOffsets ? constOffsets->size() : 0;
auto tDescRank = tDesc.getType().getRank();
if (((offsetSize != 0) && (offsetSize < tDescRank)) ||
((constOffsetSize != 0) && (constOffsetSize < tDescRank)))
return emitError() << "Offsets rank cannot be smaller than tensor "
"descriptor rank.";
return success();
}
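As a quick illustration of the relaxed check (a minimal sketch with made-up SSA names, not taken from the patch's tests): offsets may now outnumber the tensor-descriptor rank, but may not fall short of it.

```mlir
%t = xegpu.create_nd_tdesc %src
    : memref<2x8x32x32xf32> -> !xegpu.tensor_desc<8x16xf32>

// OK: 3 offsets for a rank-2 descriptor (the extra offset addresses an
// outer memref dimension). The previous verifier rejected this as a
// rank mismatch.
%v = xegpu.load_nd %t[%o0, %o1, %o2]
    : !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32>

// Error: "Offsets rank cannot be smaller than tensor descriptor rank."
%w = xegpu.load_nd %t[%o0]
    : !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32>
```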

static LogicalResult
isValidGatherScatterBufferParams(Type offsetsTy, Type maskTy,
VectorType valueTy, int64_t chunkSize,
@@ -215,8 +231,10 @@ void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
auto [memrefStrides, _] = memrefTy.getStridesAndOffset();

// if shape and strides are from Memref, we don't need attributes for them
// to keep the IR print clean.
if (staticShape == memrefShape && staticStrides == memrefStrides) {
// to keep the IR print clean (only do so for the fully static case; otherwise
// the printer would fail trying to print an empty array attribute).
if (staticShape == memrefShape && staticStrides == memrefStrides &&
dynamicShape.empty() && dynamicStrides.empty()) {
staticShapeAttr = DenseI64ArrayAttr();
staticStridesAttr = DenseI64ArrayAttr();
}
@@ -277,8 +295,10 @@ void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
auto [memrefStrides, _] = memrefTy.getStridesAndOffset();

// if shape and strides are from Memref, we don't need attributes for them
// to keep the IR print clean.
if (staticShape == memrefShape && staticStrides == memrefStrides) {
// to keep the IR print clean (only do so for the fully static case; otherwise
// the printer would fail trying to print an empty array attribute).
if (staticShape == memrefShape && staticStrides == memrefStrides &&
dynamicShape.empty() && dynamicStrides.empty()) {
staticShapeAttr = DenseI64ArrayAttr();
staticStridesAttr = DenseI64ArrayAttr();
}
@@ -428,16 +448,9 @@ LogicalResult PrefetchNdOp::verify() {
if (!isReadHintOrNone(getL3HintAttr()))
return emitOpError("invalid l3_hint: ") << getL3HintAttr();

int64_t tDescRank = tdescTy.getRank();
int64_t offsetSize = static_cast<int64_t>(getOffsets().size());
int64_t constOffsetSize =
getConstOffsetsAttr() ? getConstOffsetsAttr().size() : 0;
if (((offsetSize != 0) && (offsetSize != tDescRank)) ||
((constOffsetSize != 0) && (constOffsetSize != tDescRank)))
return emitOpError(
"Mismatched ranks between offsets and tensor descriptor");

return success();
auto tDesc = getTensorDesc();
return isValidNdOffset(tDesc, getConstOffsets(), getMixedOffsets().size(),
[&]() { return emitOpError(); });
}

//===----------------------------------------------------------------------===//
@@ -553,16 +566,9 @@ LogicalResult LoadNdOp::verify() {
<< " is not consistent with tensor descriptor "
<< tdescTy;

int64_t tDescRank = tdescTy.getRank();
int64_t offsetSize = static_cast<int64_t>(getOffsets().size());
int64_t constOffsetSize =
getConstOffsetsAttr() ? getConstOffsetsAttr().size() : 0;
if (((offsetSize != 0) && (offsetSize != tDescRank)) ||
((constOffsetSize != 0) && (constOffsetSize != tDescRank)))
return emitOpError(
"Mismatched ranks between offsets and tensor descriptor");

return success();
auto tDesc = getTensorDesc();
return isValidNdOffset(tDesc, getConstOffsets(), getMixedOffsets().size(),
[&]() { return emitOpError(); });
}

//===----------------------------------------------------------------------===//
@@ -647,16 +653,9 @@ LogicalResult StoreNdOp::verify() {
<< " is not consistent with tensor descriptor "
<< dstTy;

int64_t tDescRank = dstTy.getRank();
int64_t offsetSize = static_cast<int64_t>(getOffsets().size());
int64_t constOffsetSize =
getConstOffsetsAttr() ? getConstOffsetsAttr().size() : 0;
if (((offsetSize != 0) && (offsetSize != tDescRank)) ||
((constOffsetSize != 0) && (constOffsetSize != tDescRank)))
return emitOpError(
"Mismatched ranks between offsets and tensor descriptor");

return success();
auto tDesc = getTensorDesc();
return isValidNdOffset(tDesc, getConstOffsets(), getMixedOffsets().size(),
[&]() { return emitOpError(); });
}

//===----------------------------------------------------------------------===//
20 changes: 10 additions & 10 deletions mlir/test/Conversion/VectorToXeGPU/load-to-xegpu.mlir
@@ -10,10 +10,10 @@ func.func @load_1D_vector(%source: memref<8x16x32xf32>, %offset: index) -> vecto
// CHECK-SAME: %[[SRC:.+]]: memref<8x16x32xf32>,
// CHECK-SAME: %[[OFFSET:.+]]: index
// CHECK: %[[DESC:.+]] = xegpu.create_nd_tdesc
// CHECK-SAME: %[[SRC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]]
// CHECK-SAME: %[[SRC]]
// CHECK-SAME: memref<8x16x32xf32> -> !xegpu.tensor_desc<8xf32,
// CHECK-SAME: boundary_check = false
// CHECK: %[[VEC:.+]] = xegpu.load_nd %[[DESC]]{{.*}}-> vector<8xf32>
// CHECK: %[[VEC:.+]] = xegpu.load_nd %[[DESC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]]{{.*}}-> vector<8xf32>
// CHECK: return %[[VEC]]

// -----
Expand All @@ -29,9 +29,9 @@ func.func @load_2D_vector(%source: memref<8x16x32xf32>,
// CHECK-SAME: %[[SRC:.+]]: memref<8x16x32xf32>,
// CHECK-SAME: %[[OFFSET:.+]]: index
// CHECK: %[[DESC:.+]] = xegpu.create_nd_tdesc
// CHECK-SAME: %[[SRC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]]
// CHECK-SAME: %[[SRC]]
// CHECK-SAME: memref<8x16x32xf32> -> !xegpu.tensor_desc<8x16xf32>
// CHECK: %[[VEC:.+]] = xegpu.load_nd %[[DESC]]{{.*}}-> vector<8x16xf32>
// CHECK: %[[VEC:.+]] = xegpu.load_nd %[[DESC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]]{{.*}}-> vector<8x16xf32>
// CHECK: return %[[VEC]]

// -----
@@ -52,11 +52,11 @@ func.func @load_dynamic_source(%source: memref<?x?x?xf32>,
// CHECK-DAG: %[[DIM_0:.+]] = memref.dim %[[SRC]], %[[C0]]
// CHECK-DAG: %[[DIM_1:.+]] = memref.dim %[[SRC]], %[[C1]]
// CHECK-DAG: %[[DIM_2:.+]] = memref.dim %[[SRC]], %[[C2]]
// CHECK: %[[DIM_0_STRIDE:.+]] = arith.muli %[[DIM_2]], %[[DIM_1]]
// CHECK: %[[DESC:.+]] = xegpu.create_nd_tdesc %[[SRC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]]
// CHECK-SAME: , shape : [%[[DIM_0]], %[[DIM_1]], %[[DIM_2]]], strides : [%[[DIM_0_STRIDE]], %[[DIM_2]], 1]
// CHECK: {{.*}} %[[STRIDES:.+]]:3 = memref.extract_strided_metadata %[[SRC]]
// CHECK: %[[DESC:.+]] = xegpu.create_nd_tdesc %[[SRC]]
// CHECK-SAME: , shape : [%[[DIM_0]], %[[DIM_1]], %[[DIM_2]]], strides : [%[[STRIDES]]#0, %[[STRIDES]]#1, %c1]
// CHECK-SAME: memref<?x?x?xf32> -> !xegpu.tensor_desc<8x16xf32>
// CHECK: %[[VEC:.+]] = xegpu.load_nd %[[DESC]]{{.*}}-> vector<8x16xf32>
// CHECK: %[[VEC:.+]] = xegpu.load_nd %[[DESC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]]{{.*}}-> vector<8x16xf32>
// CHECK: return %[[VEC]]

// -----
@@ -72,9 +72,9 @@ func.func @load_out_of_bounds(%source: memref<7x15xf32>,
// CHECK-SAME: %[[SRC:.+]]: memref<7x15xf32>,
// CHECK-SAME: %[[OFFSET:.+]]: index
// CHECK: %[[DESC:.+]] = xegpu.create_nd_tdesc
// CHECK-SAME: %[[SRC]][%[[OFFSET]], %[[OFFSET]]]
// CHECK-SAME: %[[SRC]]
// CHECK-SAME: memref<7x15xf32> -> !xegpu.tensor_desc<8x16xf32>
// CHECK: %[[VEC:.+]] = xegpu.load_nd %[[DESC]]{{.*}}-> vector<8x16xf32>
// CHECK: %[[VEC:.+]] = xegpu.load_nd %[[DESC]][%[[OFFSET]], %[[OFFSET]]]{{.*}}-> vector<8x16xf32>
// CHECK: return %[[VEC]]

// -----