[MLIR][XeGPU][VectorToXeGPU] Lower vector.load/store/transfer_read/transfer_write to new offsets syntax #162095
base: main
Conversation
@llvm/pr-subscribers-mlir-gpu @llvm/pr-subscribers-mlir

Author: Dmitry Chigarev (dchigarev)

Changes

Changes the VectorToXeGPU pass to generate xegpu.load_nd/store_nd ops using the new syntax, where offsets are specified on the load/store ops themselves:

// from this
%desc = xegpu.create_nd_tdesc %src[%off1, %off2]: memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
%res = xegpu.load_nd %desc : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
// to this
%desc = xegpu.create_nd_tdesc %src: memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
%res = xegpu.load_nd %desc[%off1, %off2] : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>

In order to support cases with dimension reduction at the create_nd_tdesc level, the load_nd/store_nd/prefetch_nd verification had to be relaxed to allow the number of offsets to differ from the tensor descriptor rank.

Why we need to change that?

// reduce dim and apply all 3 offsets at load_nd
%desc = xegpu.create_nd_tdesc %source : memref<8x16x32xf32> -> !xegpu.tensor_desc<16x32xf32>
// error: xegpu.load_nd len(offsets) != desc.rank
%res = xegpu.load_nd %desc[%off, %off, %off] : !xegpu.tensor_desc<16x32xf32> -> vector<16x32xf32>

The new verification logic checks that the number of offsets matches either the source (e.g. memref's) rank or the tensor descriptor rank. Since the tensor descriptor type doesn't carry the source shape, the check only runs when the producing create_nd_tdesc is reachable; otherwise it is postponed until the xegpu-to-xevm pass.

Patch is 31.93 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/162095.diff

7 Files Affected:
diff --git a/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp b/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
index e2c7d803e5a5e..7f11d427191e5 100644
--- a/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
+++ b/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
@@ -97,18 +97,17 @@ static LogicalResult transferPreconditions(PatternRewriter &rewriter,
return success();
}
-static xegpu::CreateNdDescOp
-createNdDescriptor(PatternRewriter &rewriter, Location loc,
- xegpu::TensorDescType descType, TypedValue<MemRefType> src,
- Operation::operand_range offsets) {
+static xegpu::CreateNdDescOp createNdDescriptor(PatternRewriter &rewriter,
+ Location loc,
+ xegpu::TensorDescType descType,
+ TypedValue<MemRefType> src) {
MemRefType srcTy = src.getType();
auto [strides, offset] = srcTy.getStridesAndOffset();
xegpu::CreateNdDescOp ndDesc;
- if (srcTy.hasStaticShape()) {
- ndDesc = xegpu::CreateNdDescOp::create(rewriter, loc, descType, src,
- getAsOpFoldResult(offsets));
- } else {
+ if (srcTy.hasStaticShape())
+ ndDesc = xegpu::CreateNdDescOp::create(rewriter, loc, descType, src);
+ else {
// In case of any dynamic shapes, source's shape and strides have to be
// explicitly provided.
SmallVector<Value> sourceDims;
@@ -116,38 +115,31 @@ createNdDescriptor(PatternRewriter &rewriter, Location loc,
for (unsigned i = 0; i < srcRank; ++i)
sourceDims.push_back(memref::DimOp::create(rewriter, loc, src, i));
- SmallVector<int64_t> constOffsets;
- SmallVector<Value> dynOffsets;
- for (Value offset : offsets) {
- std::optional<int64_t> staticVal = getConstantIntValue(offset);
- if (!staticVal)
- dynOffsets.push_back(offset);
- constOffsets.push_back(staticVal.value_or(ShapedType::kDynamic));
- }
-
- SmallVector<Value> dynShapes;
+ SmallVector<OpFoldResult> mixedShapes;
for (auto [idx, shape] : llvm::enumerate(srcTy.getShape())) {
if (shape == ShapedType::kDynamic)
- dynShapes.push_back(sourceDims[idx]);
+ mixedShapes.push_back(sourceDims[idx]);
+ else
+ mixedShapes.push_back(rewriter.getI64IntegerAttr(shape));
}
// Compute strides in reverse order.
- SmallVector<Value> dynStrides;
+ SmallVector<OpFoldResult> mixedStrides;
Value accStride = arith::ConstantIndexOp::create(rewriter, loc, 1);
// Last stride is guaranteed to be static and unit.
+ mixedStrides.push_back(rewriter.getI64IntegerAttr(1));
for (int i = static_cast<int>(strides.size()) - 2; i >= 0; --i) {
accStride =
arith::MulIOp::create(rewriter, loc, accStride, sourceDims[i + 1]);
if (strides[i] == ShapedType::kDynamic)
- dynStrides.push_back(accStride);
+ mixedStrides.push_back(accStride);
+ else
+ mixedStrides.push_back(rewriter.getI64IntegerAttr(strides[i]));
}
- std::reverse(dynStrides.begin(), dynStrides.end());
+ std::reverse(mixedStrides.begin(), mixedStrides.end());
- ndDesc = xegpu::CreateNdDescOp::create(
- rewriter, loc, descType, src, dynOffsets, dynShapes, dynStrides,
- DenseI64ArrayAttr::get(rewriter.getContext(), constOffsets),
- DenseI64ArrayAttr::get(rewriter.getContext(), srcTy.getShape()),
- DenseI64ArrayAttr::get(rewriter.getContext(), strides));
+ ndDesc = xegpu::CreateNdDescOp::create(rewriter, loc, descType, src,
+ mixedShapes, mixedStrides);
}
return ndDesc;
@@ -523,21 +515,21 @@ struct TransferReadLowering : public OpRewritePattern<vector::TransferReadOp> {
descShape, elementType, /*array_length=*/1,
/*boundary_check=*/isOutOfBounds, xegpu::MemorySpace::Global);
- xegpu::CreateNdDescOp ndDesc =
- createNdDescriptor(rewriter, loc, descType,
- dyn_cast<TypedValue<MemRefType>>(readOp.getBase()),
- readOp.getIndices());
-
DenseI64ArrayAttr transposeAttr =
!isTransposeLoad ? nullptr
: DenseI64ArrayAttr::get(rewriter.getContext(),
ArrayRef<int64_t>{1, 0});
// By default, no specific caching policy is assigned.
xegpu::CachePolicyAttr hint = nullptr;
- auto loadOp = xegpu::LoadNdOp::create(rewriter, loc, vecTy, ndDesc,
- /*packed=*/nullptr, transposeAttr,
- /*l1_hint=*/hint,
- /*l2_hint=*/hint, /*l3_hint=*/hint);
+ xegpu::CreateNdDescOp ndDesc =
+ createNdDescriptor(rewriter, loc, descType,
+ dyn_cast<TypedValue<MemRefType>>(readOp.getBase()));
+
+ auto loadOp = xegpu::LoadNdOp::create(
+ rewriter, loc, vecTy, ndDesc, getAsOpFoldResult(readOp.getIndices()),
+ /*packed=*/nullptr, transposeAttr,
+ /*l1_hint=*/hint,
+ /*l2_hint=*/hint, /*l3_hint=*/hint);
rewriter.replaceOp(readOp, loadOp);
return success();
@@ -579,15 +571,15 @@ struct TransferWriteLowering
vecTy.getShape(), vecTy.getElementType(),
/*array_length=*/1, /*boundary_check=*/writeOp.hasOutOfBoundsDim(),
xegpu::MemorySpace::Global);
+ // By default, no specific caching policy is assigned.
+ xegpu::CachePolicyAttr hint = nullptr;
xegpu::CreateNdDescOp ndDesc =
createNdDescriptor(rewriter, loc, descType,
- dyn_cast<TypedValue<MemRefType>>(writeOp.getBase()),
- writeOp.getIndices());
+ dyn_cast<TypedValue<MemRefType>>(writeOp.getBase()));
- // By default, no specific caching policy is assigned.
- xegpu::CachePolicyAttr hint = nullptr;
auto storeOp =
xegpu::StoreNdOp::create(rewriter, loc, writeOp.getVector(), ndDesc,
+ getAsOpFoldResult(writeOp.getIndices()),
/*l1_hint=*/hint,
/*l2_hint=*/hint, /*l3_hint=*/hint);
rewriter.replaceOp(writeOp, storeOp);
@@ -674,17 +666,18 @@ struct LoadLowering : public OpRewritePattern<vector::LoadOp> {
// Boundary check is available only for block instructions.
bool boundaryCheck = vecTy.getRank() > 1;
+ // By default, no specific caching policy is assigned.
+ xegpu::CachePolicyAttr hint = nullptr;
auto descType = xegpu::TensorDescType::get(
vecTy.getShape(), vecTy.getElementType(), /*array_length=*/1,
boundaryCheck, xegpu::MemorySpace::Global);
- xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
- rewriter, loc, descType, loadOp.getBase(), loadOp.getIndices());
- // By default, no specific caching policy is assigned.
- xegpu::CachePolicyAttr hint = nullptr;
+ xegpu::CreateNdDescOp ndDesc =
+ createNdDescriptor(rewriter, loc, descType, loadOp.getBase());
auto loadNdOp = xegpu::LoadNdOp::create(
- rewriter, loc, vecTy, ndDesc, /*packed=*/nullptr, /*transpose=*/nullptr,
+ rewriter, loc, vecTy, ndDesc, getAsOpFoldResult(loadOp.getIndices()),
+ /*packed=*/nullptr, /*transpose=*/nullptr,
/*l1_hint=*/hint,
/*l2_hint=*/hint, /*l3_hint=*/hint);
rewriter.replaceOp(loadOp, loadNdOp);
@@ -711,15 +704,17 @@ struct StoreLowering : public OpRewritePattern<vector::StoreOp> {
auto descType = xegpu::TensorDescType::get(
vecTy.getShape(), vecTy.getElementType(),
/*array_length=*/1, boundaryCheck, xegpu::MemorySpace::Global);
- xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
- rewriter, loc, descType, storeOp.getBase(), storeOp.getIndices());
// By default, no specific caching policy is assigned.
xegpu::CachePolicyAttr hint = nullptr;
- auto storeNdOp =
- xegpu::StoreNdOp::create(rewriter, loc, vector, ndDesc,
- /*l1_hint=*/hint,
- /*l2_hint=*/hint, /*l3_hint=*/hint);
+ xegpu::CreateNdDescOp ndDesc =
+ createNdDescriptor(rewriter, loc, descType, storeOp.getBase());
+
+ auto storeNdOp = xegpu::StoreNdOp::create(
+ rewriter, loc, vector, ndDesc, getAsOpFoldResult(storeOp.getIndices()),
+ /*l1_hint=*/hint,
+ /*l2_hint=*/hint, /*l3_hint=*/hint);
+
rewriter.replaceOp(storeOp, storeNdOp);
return success();
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index e0a8ac40648e0..e09a084ac7ad2 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -121,6 +121,39 @@ isValidGatherScatterParams(Type maskTy, VectorType valueTy,
return success();
}
+// Verify that number of offsets matches either the source rank or the tdesc
+// rank.
+static LogicalResult
+isValidNdOffset(TypedValue<TensorDescType> tDesc,
+ std::optional<llvm::ArrayRef<int64_t>> constOffsets,
+ int64_t offsetSize,
+ function_ref<InFlightDiagnostic()> emitError) {
+ if (auto createTDescOp = tDesc.getDefiningOp<CreateNdDescOp>()) {
+ // If CreateNdDescOp is available, we can further
+ // check the offsets rank against the source rank.
+ auto staticSource = createTDescOp.getConstShapeAttr();
+ int64_t sourceRank;
+ if (!staticSource || staticSource.empty()) {
+ auto sourceTy = dyn_cast<MemRefType>(createTDescOp.getSourceType());
+ sourceRank = sourceTy.getRank();
+ } else
+ sourceRank = staticSource.size();
+
+ int64_t constOffsetSize = constOffsets ? constOffsets->size() : 0;
+ auto tDescRank = tDesc.getType().getRank();
+ bool sourceRankMismatch =
+ ((offsetSize != 0) && (offsetSize != sourceRank)) ||
+ ((constOffsetSize != 0) && (constOffsetSize != sourceRank));
+ bool tdescRankMismatch =
+ ((offsetSize != 0) && (offsetSize != tDescRank)) ||
+ ((constOffsetSize != 0) && (constOffsetSize != tDescRank));
+ if (sourceRankMismatch && tdescRankMismatch)
+ return emitError() << "Offsets rank must match either the source or the "
+ "TensorDesc rank.";
+ }
+ return success();
+}
+
static LogicalResult
isValidGatherScatterBufferParams(Type offsetsTy, Type maskTy,
VectorType valueTy, int64_t chunkSize,
@@ -215,8 +248,10 @@ void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
auto [memrefStrides, _] = memrefTy.getStridesAndOffset();
// if shape and strides are from Memref, we don't need attributes for them
- // to keep the IR print clean.
- if (staticShape == memrefShape && staticStrides == memrefStrides) {
+ // to keep the IR print clean (only do so for full-static case, otherwise
+ // printer would fail trying to print empty array-attr).
+ if (staticShape == memrefShape && staticStrides == memrefStrides &&
+ dynamicShape.empty() && dynamicStrides.empty()) {
staticShapeAttr = DenseI64ArrayAttr();
staticStridesAttr = DenseI64ArrayAttr();
}
@@ -277,8 +312,10 @@ void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
auto [memrefStrides, _] = memrefTy.getStridesAndOffset();
// if shape and strides are from Memref, we don't need attributes for them
- // to keep the IR print clean.
- if (staticShape == memrefShape && staticStrides == memrefStrides) {
+ // to keep the IR print clean (only do so for full-static case, otherwise
+ // printer would fail trying to print empty array-attr).
+ if (staticShape == memrefShape && staticStrides == memrefStrides &&
+ dynamicShape.empty() && dynamicStrides.empty()) {
staticShapeAttr = DenseI64ArrayAttr();
staticStridesAttr = DenseI64ArrayAttr();
}
@@ -428,16 +465,9 @@ LogicalResult PrefetchNdOp::verify() {
if (!isReadHintOrNone(getL3HintAttr()))
return emitOpError("invalid l3_hint: ") << getL3HintAttr();
- int64_t tDescRank = tdescTy.getRank();
- int64_t offsetSize = static_cast<int64_t>(getOffsets().size());
- int64_t constOffsetSize =
- getConstOffsetsAttr() ? getConstOffsetsAttr().size() : 0;
- if (((offsetSize != 0) && (offsetSize != tDescRank)) ||
- ((constOffsetSize != 0) && (constOffsetSize != tDescRank)))
- return emitOpError(
- "Mismatched ranks between offsets and tensor descriptor");
-
- return success();
+ auto tDesc = getTensorDesc();
+ return isValidNdOffset(tDesc, getConstOffsets(), getMixedOffsets().size(),
+ [&]() { return emitOpError(); });
}
//===----------------------------------------------------------------------===//
@@ -553,16 +583,9 @@ LogicalResult LoadNdOp::verify() {
<< " is not consistent with tensor descriptor "
<< tdescTy;
- int64_t tDescRank = tdescTy.getRank();
- int64_t offsetSize = static_cast<int64_t>(getOffsets().size());
- int64_t constOffsetSize =
- getConstOffsetsAttr() ? getConstOffsetsAttr().size() : 0;
- if (((offsetSize != 0) && (offsetSize != tDescRank)) ||
- ((constOffsetSize != 0) && (constOffsetSize != tDescRank)))
- return emitOpError(
- "Mismatched ranks between offsets and tensor descriptor");
-
- return success();
+ auto tDesc = getTensorDesc();
+ return isValidNdOffset(tDesc, getConstOffsets(), getMixedOffsets().size(),
+ [&]() { return emitOpError(); });
}
//===----------------------------------------------------------------------===//
@@ -647,16 +670,9 @@ LogicalResult StoreNdOp::verify() {
<< " is not consistent with tensor descriptor "
<< dstTy;
- int64_t tDescRank = dstTy.getRank();
- int64_t offsetSize = static_cast<int64_t>(getOffsets().size());
- int64_t constOffsetSize =
- getConstOffsetsAttr() ? getConstOffsetsAttr().size() : 0;
- if (((offsetSize != 0) && (offsetSize != tDescRank)) ||
- ((constOffsetSize != 0) && (constOffsetSize != tDescRank)))
- return emitOpError(
- "Mismatched ranks between offsets and tensor descriptor");
-
- return success();
+ auto tDesc = getTensorDesc();
+ return isValidNdOffset(tDesc, getConstOffsets(), getMixedOffsets().size(),
+ [&]() { return emitOpError(); });
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Conversion/VectorToXeGPU/load-to-xegpu.mlir b/mlir/test/Conversion/VectorToXeGPU/load-to-xegpu.mlir
index 9908205f07c92..b5fb2c4aa3e27 100644
--- a/mlir/test/Conversion/VectorToXeGPU/load-to-xegpu.mlir
+++ b/mlir/test/Conversion/VectorToXeGPU/load-to-xegpu.mlir
@@ -10,10 +10,10 @@ func.func @load_1D_vector(%source: memref<8x16x32xf32>, %offset: index) -> vecto
// CHECK-SAME: %[[SRC:.+]]: memref<8x16x32xf32>,
// CHECK-SAME: %[[OFFSET:.+]]: index
// CHECK: %[[DESC:.+]] = xegpu.create_nd_tdesc
-// CHECK-SAME: %[[SRC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]]
+// CHECK-SAME: %[[SRC]]
// CHECK-SAME: memref<8x16x32xf32> -> !xegpu.tensor_desc<8xf32,
// CHECK-SAME: boundary_check = false
-// CHECK: %[[VEC:.+]] = xegpu.load_nd %[[DESC]]{{.*}}-> vector<8xf32>
+// CHECK: %[[VEC:.+]] = xegpu.load_nd %[[DESC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]]{{.*}}-> vector<8xf32>
// CHECK: return %[[VEC]]
// -----
@@ -29,9 +29,9 @@ func.func @load_2D_vector(%source: memref<8x16x32xf32>,
// CHECK-SAME: %[[SRC:.+]]: memref<8x16x32xf32>,
// CHECK-SAME: %[[OFFSET:.+]]: index
// CHECK: %[[DESC:.+]] = xegpu.create_nd_tdesc
-// CHECK-SAME: %[[SRC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]]
+// CHECK-SAME: %[[SRC]]
// CHECK-SAME: memref<8x16x32xf32> -> !xegpu.tensor_desc<8x16xf32>
-// CHECK: %[[VEC:.+]] = xegpu.load_nd %[[DESC]]{{.*}}-> vector<8x16xf32>
+// CHECK: %[[VEC:.+]] = xegpu.load_nd %[[DESC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]]{{.*}}-> vector<8x16xf32>
// CHECK: return %[[VEC]]
// -----
@@ -53,10 +53,10 @@ func.func @load_dynamic_source(%source: memref<?x?x?xf32>,
// CHECK-DAG: %[[DIM_1:.+]] = memref.dim %[[SRC]], %[[C1]]
// CHECK-DAG: %[[DIM_2:.+]] = memref.dim %[[SRC]], %[[C2]]
// CHECK: %[[DIM_0_STRIDE:.+]] = arith.muli %[[DIM_2]], %[[DIM_1]]
-// CHECK: %[[DESC:.+]] = xegpu.create_nd_tdesc %[[SRC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]]
+// CHECK: %[[DESC:.+]] = xegpu.create_nd_tdesc %[[SRC]]
// CHECK-SAME: , shape : [%[[DIM_0]], %[[DIM_1]], %[[DIM_2]]], strides : [%[[DIM_0_STRIDE]], %[[DIM_2]], 1]
// CHECK-SAME: memref<?x?x?xf32> -> !xegpu.tensor_desc<8x16xf32>
-// CHECK: %[[VEC:.+]] = xegpu.load_nd %[[DESC]]{{.*}}-> vector<8x16xf32>
+// CHECK: %[[VEC:.+]] = xegpu.load_nd %[[DESC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]]{{.*}}-> vector<8x16xf32>
// CHECK: return %[[VEC]]
// -----
@@ -72,9 +72,9 @@ func.func @load_out_of_bounds(%source: memref<7x15xf32>,
// CHECK-SAME: %[[SRC:.+]]: memref<7x15xf32>,
// CHECK-SAME: %[[OFFSET:.+]]: index
// CHECK: %[[DESC:.+]] = xegpu.create_nd_tdesc
-// CHECK-SAME: %[[SRC]][%[[OFFSET]], %[[OFFSET]]]
+// CHECK-SAME: %[[SRC]]
// CHECK-SAME: memref<7x15xf32> -> !xegpu.tensor_desc<8x16xf32>
-// CHECK: %[[VEC:.+]] = xegpu.load_nd %[[DESC]]{{.*}}-> vector<8x16xf32>
+// CHECK: %[[VEC:.+]] = xegpu.load_nd %[[DESC]][%[[OFFSET]], %[[OFFSET]]]{{.*}}-> vector<8x16xf32>
// CHECK: return %[[VEC]]
// -----
diff --git a/mlir/test/Conversion/VectorToXeGPU/store-to-xegpu.mlir b/mlir/test/Conversion/VectorToXeGPU/store-to-xegpu.mlir
index 2c498dcc2a071..57e754f7d7c00 100644
--- a/mlir/test/Conversion/VectorToXeGPU/store-to-xegpu.mlir
+++ b/mlir/test/Conversion/VectorToXeGPU/store-to-xegpu.mlir
@@ -12,10 +12,10 @@ func.func @store_1D_vector(%vec: vector<8xf32>,
// CHECK-SAME: %[[SRC:.+]]: memref<8x16x32xf32>,
// CHECK-SAME: %[[OFFSET:.+]]: index
// CHECK: %[[DESC:.+]] = xegpu.create_nd_tdesc
-// CHECK-SAME: %[[SRC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]]
+// CHECK-SAME: %[[SRC]]
// CHECK-SAME: memref<8x16x32xf32> -> !xegpu.tensor_desc<8xf32,
// CHECK-SAME: boundary_check = false
-// CHECK: xegpu.store_nd %[[VEC]], %[[DESC]] : vector<8xf32>
+// CHECK: xegpu.store_nd %[[VEC]], %[[DESC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]] : vector<8xf32>
// -----
@@ -31,9 +31,9 @@ func.func @store_2D_vector(%vec: vector<8x16xf32>,
// CHECK-SAME: %[[SRC:.+]]: memref<8x16x32xf32>,
// CHECK-SAME: %[[OFFSET:.+]]: index
// CHECK: %[[DESC:.+]] = xegpu.create_nd_tdesc
-// CHECK-SAME: %[[SRC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]]
+// CHECK-SAME: %[[SRC]]
// CHECK-SAME: memref<8x16x32xf32> -> !xegpu.tensor_desc<8x16xf32>
-// CHECK: xegpu.store_nd %[[VEC]], %[[DESC]] : vector<8x16xf32>
+// CHECK: xegpu.store_nd %[[VEC]], %[[DESC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]] : vector<8x16xf32>
// -----
@@ -55,10 +55,10 @@ func.func @store_dynamic_source(%vec: vector<8x16xf32>,
// CHECK-DAG: %[[DIM_1:.+]] = memref.dim %[[SRC]], %[[C1]]
// CHECK-DAG: %[[DIM_2:.+]] = memref.dim %[[SRC]], %[[C2]]
// CHECK: %[[DIM_0_STRIDE:.+]] = arith.muli %[[DIM_2]], %[[DIM_1]]
-// CHECK: %[[DESC:.+]] = xegpu.create_nd_tdesc %[[SRC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]]
+// CHECK: %[[DESC:.+]] = xegpu.create_nd_tdesc %[[SRC]]
// CHECK-SAME: , shape : [%[[DIM_0]], %[[DIM_1]], %[[DIM_2]]], strides : [%[[DIM_0_STRIDE]], %[[DIM_2]], 1]
// CHECK-SAME: memref<?x?x?xf32> -> !xegpu.tensor_desc<8x16xf32>
-// CHECK: xegpu.store_nd %[[VEC]], %[[DESC]] : vector<8x16xf32>
+// CHECK: xegpu.store_nd %[[VEC]], %[[DESC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]] : vector<8x16xf32>
// -----
@@ -74,9 +74,9 @@ func.func @store_out_of_bounds(%vec: vector<8x16xf32>,
// CHECK-SAME: %[[SRC:.+]]: memref<7x64xf32>,
// CHECK-SAME: %[[OFFSET:.+]]: index
// CHECK: %[[DESC:.+]] = xegpu.create_nd_tdesc
-// CHECK-SAME: %[[SRC]][%[[OFFSET]], %[[OFFSET]]]
+// CHECK-SAME: %[[SRC]]
// CHECK-SAME: memref<7x64xf32> -> !xegpu.tensor_desc<8x16xf32>
-// CHECK: xegpu.store_nd %[[VEC]], %[[DESC]] : vector<8x16xf32>
+// CHECK: xegpu.store_nd %[[VEC]], %[...
[truncated]
LGTM with some comments
// Verify that number of offsets matches either the source rank or the tdesc
I think we should simply check that offsets can't have a lower rank than the tdesc rank. It should be fine if it is larger than the tdesc rank.
I am not sure it is common practice for an op's verifier to validate the op using information from its producer op. I would rather check this in transformation or lowering passes. Any opinion from @adam-smnk ?
Op verifier should be contained to the op itself without accessing any external data.
> I think we should simply check that offsets can't have a lower rank than the tdesc rank. It should be fine if it is larger than the tdesc rank.

Sounds reasonable. Just please clearly define how the offsets are interpreted.
For example, if the memref is 4D, the descriptor is 2D, and there are 3 offsets, I guess they'd apply to the innermost dimensions.
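A minimal sketch of the case being described, purely to make the shapes concrete (the memref type, value names, and offsets are hypothetical, and how such offsets should be interpreted is exactly what is being discussed here):

// 4D source, 2D descriptor, 3 offsets supplied at the load.
%desc = xegpu.create_nd_tdesc %src : memref<2x8x16x32xf32> -> !xegpu.tensor_desc<16x32xf32>
// Under the reviewer's reading, %o1/%o2/%o3 would presumably address the
// three innermost dimensions (8, 16, 32) of the source.
%res = xegpu.load_nd %desc[%o1, %o2, %o3] : !xegpu.tensor_desc<16x32xf32> -> vector<16x32xf32>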
Simplified the check to not rely on CreateNdDescOp.
Also added docs for the new offset syntax.
Looking at the older semantics: is there any clear benefit in relaxing the load offset semantics?
(can be a separate PR)
Could you also update ops' docs with the new offset semantics?
✅ With the latest revision this PR passed the C/C++ code formatter.
// expected-error@+1 {{Mismatched ranks between offsets and tensor descriptor}}
%2 = xegpu.load_nd %1[0, 0] : !xegpu.tensor_desc<16xf16> -> vector<16xf16>
return
}
Actually this is a great question. The original motivation is that users want to express a 2D block in an nD tensor with the higher dimensions (n>=2) flattened into the 2nd dimension. After that, the user only works with the flattened 2D tensor and moves the 2D block within it. So I think the right solution is what you proposed: as we move the offsets to load_nd, the XeGPU user must explicitly insert a memref.subview to flatten the nD tensor to 2D, and then create a 2D tensor descriptor. Allowing 3D+ offsets creates an issue during lowering to XeVM, even if we allow it at IR creation time. The tensor descriptor in hardware only tracks the stride of the innermost dimension, so it only supports 2D block loads. Allowing 3D+ offsets means we would need to either compute the flattened 2D offsets inside the K-loop, or expand the tensor descriptor to track more than one stride. This increases complexity and may cause a negative performance impact. Instead, I think the right balance between "ease of use" and "performance" is to ask the user to do a subview upfront to flatten the tensor. See the discussion here: #164701
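A minimal sketch of the flatten-first approach described above (the function name and shapes are hypothetical; it uses memref.collapse_shape for the flattening step since that is the memref op that merges dimensions, whereas the comment mentions memref.subview — the idea is the same: produce a 2D view up front, then build a plain 2D descriptor and keep the offsets on the load):

func.func @flatten_then_load(%src: memref<8x16x32xf16>, %off0: index, %off1: index) -> vector<8x16xf16> {
  // Merge the two outer dimensions (8x16 -> 128) so only a 2D view remains.
  %flat = memref.collapse_shape %src [[0, 1], [2]]
      : memref<8x16x32xf16> into memref<128x32xf16>
  // Create a plain 2D descriptor; offsets are supplied at the load itself.
  %desc = xegpu.create_nd_tdesc %flat : memref<128x32xf16> -> !xegpu.tensor_desc<8x16xf16>
  %res = xegpu.load_nd %desc[%off0, %off1] : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
  return %res : vector<8x16xf16>
}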
Changes the VectorToXeGPU pass to generate xegpu.load_nd/store_nd ops using the new syntax, where offsets are specified on the load/store ops themselves.

In order to support cases with dimension reduction at the create_nd_tdesc level (e.g. memref<8x8x16xf16> -> tensor_desc<8x16xf16>), I had to tweak the load_nd/store_nd/prefetch_nd verification to allow the number of offsets to mismatch the tensor descriptor rank.

Why we need to change that?

The new verification logic checks that the number of offsets matches either the source (e.g. memref's) rank or the tensor descriptor rank. Since TensorDescType doesn't carry source shape information (only the shape of the tile), we can only run this verification when the create_nd_tdesc operation is reachable via tdesc.getDefiningOp(). If it is not, we skip the verification and postpone it until the xegpu-to-xevm pass, where create_nd_tdesc must be available.
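A minimal sketch of what the relaxed verifier accepts, mirroring the rule above (value names and shapes are illustrative, not taken from the patch):

%desc = xegpu.create_nd_tdesc %src : memref<8x16x32xf32> -> !xegpu.tensor_desc<16x32xf32>
// OK: three offsets — matches the source (memref) rank.
%a = xegpu.load_nd %desc[%i, %j, %k] : !xegpu.tensor_desc<16x32xf32> -> vector<16x32xf32>
// OK: two offsets — matches the tensor descriptor rank.
%b = xegpu.load_nd %desc[%j, %k] : !xegpu.tensor_desc<16x32xf32> -> vector<16x32xf32>
// Rejected: one offset matches neither rank, so the verifier emits
// "Offsets rank must match either the source or the TensorDesc rank."
%c = xegpu.load_nd %desc[%i] : !xegpu.tensor_desc<16x32xf32> -> vector<16x32xf32>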