Skip to content

Commit bb65e50

Browse files
committed
[water] add wave.permute op
Signed-off-by: Tim Gymnich <[email protected]>
1 parent 28b6807 commit bb65e50

File tree

13 files changed

+499
-25
lines changed

13 files changed

+499
-25
lines changed

water/include/water/Dialect/Wave/IR/WaveInterfaces.h

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -146,7 +146,7 @@ llvm::LogicalResult verifyElementTypesMatch(std::optional<mlir::Location> loc,
146146
// location is provided.
147147
llvm::LogicalResult verifyTypesCompatible(
148148
wave::WaveTensorType lhs, wave::WaveTensorType rhs,
149-
bool includeAddressSpace,
149+
bool includeAddressSpace, bool includeShape = true,
150150
std::optional<mlir::Location> errorLocation = std::nullopt,
151151
llvm::StringRef lhsName = "", llvm::StringRef rhsName = "");
152152

@@ -161,11 +161,10 @@ verifyTypesMatchingDimensions(std::optional<mlir::Location> loc,
161161
llvm::ArrayRef<int> rhsDims);
162162

163163
// Verification logic for the compatible-operands traits. Succeeds if all wave
164-
// tensor-typed operands and results have compatible shapes and, if the
165-
// corresponding flag is set, compatible address spaces.
166-
llvm::LogicalResult
167-
verifyCompatibleOperandsAndResultsOpTrait(mlir::Operation *op,
168-
bool includeAddressSpace);
164+
// tensor-typed operands and results have compatible shapes (unless includeShape
165+
// is false) and, if the corresponding flag is set, compatible address spaces.
166+
llvm::LogicalResult verifyCompatibleOperandsAndResultsOpTrait(
167+
mlir::Operation *op, bool includeAddressSpace, bool includeShape = true);
169168
}; // namespace detail
170169

171170
template <typename OpTy>
@@ -190,6 +189,17 @@ class CompatibleOperandsAndResultsIgnoreSpaceOpTrait
190189
}
191190
};
192191

192+
template <typename OpTy>
193+
class CompatibleOperandsAndResultsIgnoreShapeOpTrait
194+
: public mlir::OpTrait::TraitBase<
195+
OpTy, CompatibleOperandsAndResultsIgnoreShapeOpTrait> {
196+
public:
197+
static llvm::LogicalResult verifyTrait(mlir::Operation *op) {
198+
return detail::verifyCompatibleOperandsAndResultsOpTrait(
199+
op, /*includeAddressSpace=*/true, /*includeShape=*/false);
200+
}
201+
};
202+
193203
//-----------------------------------------------------------------------------
194204
// WaveElementsPerThreadOpInterface
195205
//-----------------------------------------------------------------------------

water/include/water/Dialect/Wave/IR/WaveInterfaces.td

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,11 @@ def CompatibleOperandsAndResultsIgnoreSpaceOpTrait
138138
let cppNamespace = "::wave";
139139
}
140140

141+
def CompatibleOperandsAndResultsIgnoreShapeOpTrait
142+
: NativeOpTrait<"CompatibleOperandsAndResultsIgnoreShapeOpTrait"> {
143+
let cppNamespace = "::wave";
144+
}
145+
141146
//-----------------------------------------------------------------------------
142147
// WaveInferIndexExprsOpInterface and implementation traits
143148
//-----------------------------------------------------------------------------

water/include/water/Dialect/Wave/IR/WaveOps.td

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -505,4 +505,34 @@ def CastOp : WaveOp<"cast", [
505505
let hasVerifier = 1;
506506
}
507507

508+
def PermuteOp : WaveOp<"permute", [
509+
DeclareOpInterfaceMethods<WaveInferTypeOpInterface>, CompatibleOperandsAndResultsIgnoreShapeOpTrait,
510+
WaveElementsPerThreadOpInterface, IdentityElementsPerThreadOpTrait,
511+
DeclareOpInterfaceMethods<WaveInferIndexExprsOpInterface>]> {
512+
let summary = "Permute the dimensions of a register-resident tensor";
513+
let description = [{
514+
Reorders the symbolic dimensions of a register-resident tensor according
515+
to the specified target shape. This operation is primarily a semantic
516+
marker that affects how index expressions are transformed during
517+
compilation. At lowering time, the operation is a pass-through since the
518+
actual data layout in registers remains unchanged - only the interpretation
519+
of which dimension each element belongs to changes.
520+
521+
For example, permuting a tensor with shape [B, M, N] to [M, N, B] swaps
522+
the strides associated with the symbolic dimensions in the index
523+
expressions.
524+
}];
525+
let arguments = !con((ins
526+
Arg<WaveTensorInRegister, "Value to permute">:$value,
527+
Arg<WaveSymbolArrayAttr , "Target dimension ordering">:$target_shape
528+
), commonArguments);
529+
let results = (outs
530+
Res<WaveTensorInRegister, "Permuted value">:$result
531+
);
532+
533+
let assemblyFormat =
534+
"$value `,` $target_shape " # commonArgumentsSyntax # " attr-dict `:`"
535+
"functional-type(operands, results)";
536+
}
537+
508538
#endif // WATER_DIALECT_WAVE_WAVEOPS

water/lib/Dialect/Wave/IR/WaveInterfaces.cpp

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -375,8 +375,9 @@ llvm::LogicalResult wave::detail::verifyElementTypesMatch(
375375

376376
llvm::LogicalResult wave::detail::verifyTypesCompatible(
377377
wave::WaveTensorType lhs, wave::WaveTensorType rhs,
378-
bool includeAddressSpace, std::optional<Location> errorLocation,
379-
llvm::StringRef lhsName, llvm::StringRef rhsName) {
378+
bool includeAddressSpace, bool includeShape,
379+
std::optional<Location> errorLocation, llvm::StringRef lhsName,
380+
llvm::StringRef rhsName) {
380381
// Fast and cheap path.
381382
if (lhs == rhs)
382383
return success();
@@ -402,6 +403,9 @@ llvm::LogicalResult wave::detail::verifyTypesCompatible(
402403
verifyElementTypesMatch(errorLocation, lhsName, lhs, rhsName, rhs)))
403404
return failure();
404405

406+
if (!includeShape)
407+
return success();
408+
405409
if (!lhs.getFullySpecified() || !rhs.getFullySpecified())
406410
return success();
407411

@@ -422,7 +426,7 @@ llvm::LogicalResult wave::detail::verifyTypesCompatible(
422426
static llvm::LogicalResult
423427
verifyTypeRange(Location loc, TypeRange range,
424428
wave::WaveTensorType referenceType, bool includeAddressSpace,
425-
llvm::StringRef rangeDescriptionPrefix,
429+
bool includeShape, llvm::StringRef rangeDescriptionPrefix,
426430
llvm::StringRef referenceDescription) {
427431
llvm::SmallString<16> rangeDescription(rangeDescriptionPrefix);
428432
for (auto &&[i, type] : llvm::enumerate(range)) {
@@ -435,16 +439,16 @@ verifyTypeRange(Location loc, TypeRange range,
435439
os << i;
436440

437441
if (failed(wave::detail::verifyTypesCompatible(
438-
tensorType, referenceType, includeAddressSpace, loc, os.str(),
439-
referenceDescription))) {
442+
tensorType, referenceType, includeAddressSpace, includeShape, loc,
443+
os.str(), referenceDescription))) {
440444
return llvm::failure();
441445
}
442446
}
443447
return llvm::success();
444448
}
445449

446450
llvm::LogicalResult wave::detail::verifyCompatibleOperandsAndResultsOpTrait(
447-
Operation *op, bool includeAddressSpace) {
451+
Operation *op, bool includeAddressSpace, bool includeShape) {
448452
const llvm::StringLiteral kOperandNamePrefix = "operand #";
449453
const llvm::StringLiteral kResultNamePrefix = "result #";
450454
std::string referenceDescription;
@@ -470,11 +474,12 @@ llvm::LogicalResult wave::detail::verifyCompatibleOperandsAndResultsOpTrait(
470474

471475
if (llvm::failed(verifyTypeRange(op->getLoc(), op->getOperandTypes(),
472476
referenceType, includeAddressSpace,
473-
kOperandNamePrefix, os.str())))
477+
includeShape, kOperandNamePrefix, os.str())))
474478
return llvm::failure();
475479

476480
return verifyTypeRange(op->getLoc(), op->getResultTypes(), referenceType,
477-
includeAddressSpace, kResultNamePrefix, os.str());
481+
includeAddressSpace, includeShape, kResultNamePrefix,
482+
os.str());
478483
}
479484

480485
//-----------------------------------------------------------------------------

0 commit comments

Comments
 (0)