Skip to content

Commit e04202b

Browse files
committed
generalize 'offsets-check'
Signed-off-by: dchigarev <[email protected]>
1 parent 8581183 commit e04202b

File tree

2 files changed

+39
-89
lines changed

2 files changed

+39
-89
lines changed

mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp

Lines changed: 39 additions & 81 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,39 @@ isValidGatherScatterParams(Type maskTy, VectorType valueTy,
121121
return success();
122122
}
123123

124+
// Verify that number of offsets matches either the source rank or the tdesc
125+
// rank.
126+
static LogicalResult
127+
isValidNdOffset(TypedValue<TensorDescType> tDesc,
128+
std::optional<llvm::ArrayRef<long int>> constOffsets,
129+
int64_t offsetSize,
130+
function_ref<InFlightDiagnostic()> emitError) {
131+
if (auto createTDescOp = tDesc.getDefiningOp<CreateNdDescOp>()) {
132+
// If CreateNdDescOp is available, we can further
133+
// check the offsets rank against the source rank.
134+
auto staticSource = createTDescOp.getConstShapeAttr();
135+
int64_t sourceRank;
136+
if (!staticSource || staticSource.empty()) {
137+
auto sourceTy = dyn_cast<MemRefType>(createTDescOp.getSourceType());
138+
sourceRank = sourceTy.getRank();
139+
} else
140+
sourceRank = staticSource.size();
141+
142+
int64_t constOffsetSize = constOffsets ? constOffsets->size() : 0;
143+
auto tDescRank = tDesc.getType().getRank();
144+
bool sourceRankMismatch =
145+
((offsetSize != 0) && (offsetSize != sourceRank)) ||
146+
((constOffsetSize != 0) && (constOffsetSize != sourceRank));
147+
bool tdescRankMismatch =
148+
((offsetSize != 0) && (offsetSize != tDescRank)) ||
149+
((constOffsetSize != 0) && (constOffsetSize != tDescRank));
150+
if (sourceRankMismatch && tdescRankMismatch)
151+
return emitError() << "Offsets rank must match either the source or the "
152+
"TensorDesc rank.";
153+
}
154+
return success();
155+
}
156+
124157
static LogicalResult
125158
isValidGatherScatterBufferParams(Type offsetsTy, Type maskTy,
126159
VectorType valueTy, int64_t chunkSize,
@@ -433,33 +466,8 @@ LogicalResult PrefetchNdOp::verify() {
433466
return emitOpError("invalid l3_hint: ") << getL3HintAttr();
434467

435468
auto tDesc = getTensorDesc();
436-
if (auto createTDescOp = tDesc.getDefiningOp<CreateNdDescOp>()) {
437-
// If CreateNdDescOp is available, we can further
438-
// check the offsets rank against the source rank.
439-
auto staticSource = createTDescOp.getConstShapeAttr();
440-
int64_t sourceRank;
441-
if (!staticSource || staticSource.empty()) {
442-
auto sourceTy = dyn_cast<MemRefType>(createTDescOp.getSourceType());
443-
sourceRank = sourceTy.getRank();
444-
} else
445-
sourceRank = staticSource.size();
446-
447-
int64_t offsetSize = static_cast<int64_t>(getOffsets().size());
448-
int64_t constOffsetSize =
449-
getConstOffsetsAttr() ? getConstOffsetsAttr().size() : 0;
450-
auto tDescRank = tdescTy.getRank();
451-
bool sourceRankMismatch =
452-
((offsetSize != 0) && (offsetSize != sourceRank)) ||
453-
((constOffsetSize != 0) && (constOffsetSize != sourceRank));
454-
bool tdescRankMismatch =
455-
((offsetSize != 0) && (offsetSize != tDescRank)) ||
456-
((constOffsetSize != 0) && (constOffsetSize != tDescRank));
457-
if (sourceRankMismatch && tdescRankMismatch)
458-
return emitOpError(
459-
"Offsets rank must match either the source or the TensorDesc rank.");
460-
}
461-
462-
return success();
469+
return isValidNdOffset(tDesc, getConstOffsets(), getMixedOffsets().size(),
470+
[&]() { return emitOpError(); });
463471
}
464472

465473
//===----------------------------------------------------------------------===//
@@ -576,33 +584,8 @@ LogicalResult LoadNdOp::verify() {
576584
<< tdescTy;
577585

578586
auto tDesc = getTensorDesc();
579-
if (auto createTDescOp = tDesc.getDefiningOp<CreateNdDescOp>()) {
580-
// If CreateNdDescOp is available, we can further
581-
// check the offsets rank against the source rank.
582-
auto staticSource = createTDescOp.getConstShapeAttr();
583-
int64_t sourceRank;
584-
if (!staticSource || staticSource.empty()) {
585-
auto sourceTy = dyn_cast<MemRefType>(createTDescOp.getSourceType());
586-
sourceRank = sourceTy.getRank();
587-
} else
588-
sourceRank = staticSource.size();
589-
590-
int64_t offsetSize = static_cast<int64_t>(getOffsets().size());
591-
int64_t constOffsetSize =
592-
getConstOffsetsAttr() ? getConstOffsetsAttr().size() : 0;
593-
auto tDescRank = tdescTy.getRank();
594-
bool sourceRankMismatch =
595-
((offsetSize != 0) && (offsetSize != sourceRank)) ||
596-
((constOffsetSize != 0) && (constOffsetSize != sourceRank));
597-
bool tdescRankMismatch =
598-
((offsetSize != 0) && (offsetSize != tDescRank)) ||
599-
((constOffsetSize != 0) && (constOffsetSize != tDescRank));
600-
if (sourceRankMismatch && tdescRankMismatch)
601-
return emitOpError(
602-
"Offsets rank must match either the source or the TensorDesc rank.");
603-
}
604-
605-
return success();
587+
return isValidNdOffset(tDesc, getConstOffsets(), getMixedOffsets().size(),
588+
[&]() { return emitOpError(); });
606589
}
607590

608591
//===----------------------------------------------------------------------===//
@@ -688,33 +671,8 @@ LogicalResult StoreNdOp::verify() {
688671
<< dstTy;
689672

690673
auto tDesc = getTensorDesc();
691-
if (auto createTDescOp = tDesc.getDefiningOp<CreateNdDescOp>()) {
692-
// If CreateNdDescOp is available, we can further
693-
// check the offsets rank against the source rank.
694-
auto staticSource = createTDescOp.getConstShapeAttr();
695-
int64_t sourceRank;
696-
if (!staticSource || staticSource.empty()) {
697-
auto sourceTy = dyn_cast<MemRefType>(createTDescOp.getSourceType());
698-
sourceRank = sourceTy.getRank();
699-
} else
700-
sourceRank = staticSource.size();
701-
702-
int64_t offsetSize = static_cast<int64_t>(getOffsets().size());
703-
int64_t constOffsetSize =
704-
getConstOffsetsAttr() ? getConstOffsetsAttr().size() : 0;
705-
auto tDescRank = dstTy.getRank();
706-
bool sourceRankMismatch =
707-
((offsetSize != 0) && (offsetSize != sourceRank)) ||
708-
((constOffsetSize != 0) && (constOffsetSize != sourceRank));
709-
bool tdescRankMismatch =
710-
((offsetSize != 0) && (offsetSize != tDescRank)) ||
711-
((constOffsetSize != 0) && (constOffsetSize != tDescRank));
712-
if (sourceRankMismatch && tdescRankMismatch)
713-
return emitOpError(
714-
"Offsets rank must match either the source or the TensorDesc rank.");
715-
}
716-
717-
return success();
674+
return isValidNdOffset(tDesc, getConstOffsets(), getMixedOffsets().size(),
675+
[&]() { return emitOpError(); });
718676
}
719677

720678
//===----------------------------------------------------------------------===//

mlir/test/Dialect/XeGPU/invalid.mlir

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -157,14 +157,6 @@ func.func @subgroup_load_nd_offset_3(%src: memref<4x8x16xf16>, %x : index) {
157157
return
158158
}
159159

160-
// -----
161-
func.func @subgroup_load_nd_offset_4(%src: memref<4x8x16xf16>, %x : index) {
162-
%3 = xegpu.create_nd_tdesc %src: memref<4x8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
163-
// expected-error@+1 {{Offsets rank must match either the source or the TensorDesc rank.}}
164-
%5 = xegpu.load_nd %3[0] : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
165-
return
166-
}
167-
168160
// -----
169161
func.func @load_nd_layout(%src: memref<24x32xf32>) {
170162
%1 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<16xf32>

0 commit comments

Comments
 (0)