Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 23 additions & 7 deletions mlir/include/mlir/Dialect/LLVMIR/BasicPtxBuilderInterface.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,11 @@ namespace NVVM {
enum class PTXRegisterMod {
/// Read register with no modifier
Read = 0,
/// Read register with '+' modifier
/// Write register with '=' modifier
Write = 2,
/// Read register with '=' modifier.
/// Note that, this is not natively supported by LLVM, but it is possible to
/// set read and write for the same operand.
/// ReadWrite register with '+' modifier.
/// Note that, this is not natively supported by LLVM, the Interface does
/// mapping
ReadWrite = 1,
};

Expand Down Expand Up @@ -67,13 +67,19 @@ class PtxBuilder {
SmallVector<Value> ptxOperands;
// Register constraints (read, write, readwrite) and register data types
std::string registerConstraints;

// Modifiers
SmallVector<PTXRegisterMod> registerModifiers;
// Has return value as write-only or read-write
bool hasResult = false;
// Indicates if the Op will handle the register mapping manually.
bool needsManualRegisterMapping = false;

public:
/// Single constructor that only initializes members.
PtxBuilder(Operation *op, PatternRewriter &rewriter)
: interfaceOp(op), rewriter(rewriter) {}
PtxBuilder(Operation *op, PatternRewriter &rewriter,
bool needsManualRegisterMapping = false)
: interfaceOp(op), rewriter(rewriter),
needsManualRegisterMapping(needsManualRegisterMapping) {}

/// Add an operand with the read/write input type.
void insertValue(Value v, PTXRegisterMod itype = PTXRegisterMod::Read);
Expand All @@ -87,6 +93,16 @@ class PtxBuilder {
void buildAndReplaceOp();
};

/// Count the number of placeholder variables such as {$r}, {$w}, {$rw} in the
/// PTX code.
void countPlaceholderNumbers(StringRef ptxCode,
llvm::SmallDenseSet<unsigned> &seenRW,
llvm::SmallDenseSet<unsigned> &seenW,
llvm::SmallDenseSet<unsigned> &seenR,
llvm::SmallVectorImpl<unsigned> &rwNums,
llvm::SmallVectorImpl<unsigned> &wNums,
llvm::SmallVectorImpl<unsigned> &rNums);

} // namespace NVVM
} // namespace mlir

Expand Down
13 changes: 8 additions & 5 deletions mlir/include/mlir/Dialect/LLVMIR/BasicPtxBuilderInterface.td
Original file line number Diff line number Diff line change
Expand Up @@ -124,19 +124,21 @@ def BasicPtxBuilderOpInterface : OpInterface<"BasicPtxBuilderInterface"> {
following this order:
1) Adds results
2) Adds operands
3) Adds attributes
3) Adds attributes
Returns true if the OP is going to do register mapping itself
}],
/*retType=*/"void",
/*retType=*/"bool",
/*methodName=*/"getAsmValues",
/*args=*/(ins "::mlir::RewriterBase &":$rewriter,
"llvm::SmallVectorImpl<std::pair<mlir::Value, mlir::NVVM::PTXRegisterMod>>&" : $asmValues),
"llvm::SmallVectorImpl<std::pair<mlir::Value, mlir::NVVM::PTXRegisterMod>>&" : $asmValues
),
/*methodBody=*/"",
/*defaultImpl=*/ [{
mlir::Operation* op = $_op;

// Step 1. Add results
for (auto val : op->getResults())
asmValues.push_back({val, mlir::NVVM::PTXRegisterMod::Write});
for (auto val : op->getResults())
asmValues.push_back({val, mlir::NVVM::PTXRegisterMod::Write});

// Step 2. Add operands
for (auto val : op->getOperands())
Expand All @@ -149,6 +151,7 @@ def BasicPtxBuilderOpInterface : OpInterface<"BasicPtxBuilderInterface"> {
asmValues.push_back({val, mlir::NVVM::PTXRegisterMod::Read});
}
}
return false; // No manual mapping needed
}]
>
];
Expand Down
22 changes: 14 additions & 8 deletions mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -315,16 +315,19 @@ def NVVM_InlinePtxOp : NVVM_Op<"inline_ptx",
}];

let arguments = (ins Variadic<AnyType>:$readOnlyArgs,
Variadic<AnyType>:$readWriteArgs,
StrAttr:$ptxCode,
PtxPredicate:$predicate);

let results = (outs Variadic<AnyType>:$writeOnlyArgs);

let assemblyFormat = [{
$ptxCode `(` $readOnlyArgs `)`
(`,` `predicate` `=` $predicate^)? attr-dict
`:` type(operands)
(`->` type($writeOnlyArgs)^)?

let assemblyFormat = [{
$ptxCode
( `ro` `(` $readOnlyArgs^ `:` type($readOnlyArgs) `)` )?
( `rw` `(` $readWriteArgs^ `:` type($readWriteArgs) `)` )?
(`,` `predicate` `=` $predicate^)?
attr-dict
( `->` type($writeOnlyArgs)^ )?
}];

let extraClassDefinition = [{
Expand All @@ -333,6 +336,10 @@ def NVVM_InlinePtxOp : NVVM_Op<"inline_ptx",
return std::string(ptxInstStr.data());
}
}];

let extraClassDeclaration = [{
bool getAsmValues(RewriterBase &, llvm::SmallVectorImpl<std::pair<mlir::Value, mlir::NVVM::PTXRegisterMod>> &);
}];
}

//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -3027,8 +3034,7 @@ def NVVM_WgmmaMmaAsyncOp : NVVM_Op<"wgmma.mma_async",
let hasVerifier = 1;

let extraClassDeclaration = [{
void getAsmValues(RewriterBase &rewriter,
llvm::SmallVectorImpl<std::pair<mlir::Value, mlir::NVVM::PTXRegisterMod>> &asmValues);
bool getAsmValues(RewriterBase &, llvm::SmallVectorImpl<std::pair<mlir::Value, mlir::NVVM::PTXRegisterMod>> &);
}];
}

Expand Down
4 changes: 2 additions & 2 deletions mlir/lib/Conversion/NVVMToLLVM/NVVMToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,9 +57,9 @@ struct PtxLowering

SmallVector<std::pair<Value, PTXRegisterMod>> asmValues;
LDBG() << op.getPtx();
PtxBuilder generator(op, rewriter);

op.getAsmValues(rewriter, asmValues);
bool needsManualMapping = op.getAsmValues(rewriter, asmValues);
PtxBuilder generator(op, rewriter, needsManualMapping);
for (auto &[asmValue, modifier] : asmValues) {
LDBG() << asmValue << "\t Modifier : " << modifier;
generator.insertValue(asmValue, modifier);
Expand Down
Loading