Skip to content

Commit 4d791f0

Browse files
authored
[NFC] Move getTiedArgs into TritonGPU utils (#7277)
Moves `getTiedArgs` into `TritonGPU`/`Utils` for other passes to use.
1 parent 75d27b0 commit 4d791f0

File tree

3 files changed

+36
-32
lines changed

3 files changed

+36
-32
lines changed

include/triton/Dialect/TritonGPU/Transforms/Utility.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -272,6 +272,10 @@ void replaceUsesWithLocalLoad(
272272
// after converting loads into async loads.
273273
bool comesFromLoadOrBlockArg(Value v);
274274

275+
// For structured control flow ops, returns the values associated with the
276+
// `resultIdx`th result.
277+
SmallVector<Value> getTiedArgs(Operation *op, int resultIdx);
278+
275279
} // namespace mlir::triton
276280

277281
#endif // TRITON_DIALECT_TRITONGPU_TRANSFORMS_UTILITY_H_

lib/Dialect/TritonGPU/Transforms/Utility.cpp

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1586,4 +1586,36 @@ bool comesFromLoadOrBlockArg(Value v) {
15861586
isa<LoadOp, DescriptorLoadOp, DescriptorGatherOp>(v.getDefiningOp()));
15871587
}
15881588

1589+
SmallVector<Value> getTiedArgs(Operation *op, int resultIdx) {
1590+
if (auto forOp = dyn_cast<scf::ForOp>(op)) {
1591+
auto iterArg = forOp.getRegionIterArg(resultIdx);
1592+
auto result = forOp.getResult(resultIdx);
1593+
auto yieldVal = forOp.getBody()->getTerminator()->getOperand(resultIdx);
1594+
auto initVal = forOp.getInitArgs()[resultIdx];
1595+
return {iterArg, result, yieldVal, initVal};
1596+
} else if (auto whileOp = dyn_cast<scf::WhileOp>(op)) {
1597+
auto iterArg = whileOp.getBeforeArguments()[resultIdx];
1598+
auto result = whileOp.getResults()[resultIdx];
1599+
auto yieldVal =
1600+
whileOp.getBeforeBody()->getTerminator()->getOperand(resultIdx);
1601+
auto initVal = whileOp.getOperands()[resultIdx];
1602+
return {iterArg, result, iterArg, initVal};
1603+
} else if (auto ifOp = dyn_cast<scf::IfOp>(op)) {
1604+
SmallVector<Value> values;
1605+
for (auto &block : ifOp.getThenRegion().getBlocks()) {
1606+
auto terminator = block.getTerminator();
1607+
if (isa<scf::YieldOp>(terminator))
1608+
values.push_back(terminator->getOperands()[resultIdx]);
1609+
}
1610+
for (auto &block : ifOp.getElseRegion().getBlocks()) {
1611+
auto terminator = block.getTerminator();
1612+
if (isa<scf::YieldOp>(terminator))
1613+
values.push_back(terminator->getOperands()[resultIdx]);
1614+
}
1615+
values.push_back(ifOp->getResults()[resultIdx]);
1616+
return values;
1617+
}
1618+
return {};
1619+
}
1620+
15891621
} // namespace mlir::triton

lib/Dialect/TritonNvidiaGPU/Transforms/OptimizeDescriptorEncoding.cpp

Lines changed: 0 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -128,38 +128,6 @@ namespace nvidia_gpu {
128128

129129
namespace {
130130

131-
SmallVector<Value> getTiedArgs(Operation *op, int resultIdx) {
132-
if (auto forOp = dyn_cast<scf::ForOp>(op)) {
133-
auto iterArg = forOp.getRegionIterArg(resultIdx);
134-
auto result = forOp.getResult(resultIdx);
135-
auto yieldVal = forOp.getBody()->getTerminator()->getOperand(resultIdx);
136-
auto initVal = forOp.getInitArgs()[resultIdx];
137-
return {iterArg, result, yieldVal, initVal};
138-
} else if (auto whileOp = dyn_cast<scf::WhileOp>(op)) {
139-
auto iterArg = whileOp.getBeforeArguments()[resultIdx];
140-
auto result = whileOp.getResults()[resultIdx];
141-
auto yieldVal =
142-
whileOp.getBeforeBody()->getTerminator()->getOperand(resultIdx);
143-
auto initVal = whileOp.getOperands()[resultIdx];
144-
return {iterArg, result, iterArg, initVal};
145-
} else if (auto ifOp = dyn_cast<scf::IfOp>(op)) {
146-
SmallVector<Value> values;
147-
for (auto &block : ifOp.getThenRegion().getBlocks()) {
148-
auto terminator = block.getTerminator();
149-
if (isa<scf::YieldOp>(terminator))
150-
values.push_back(terminator->getOperands()[resultIdx]);
151-
}
152-
for (auto &block : ifOp.getElseRegion().getBlocks()) {
153-
auto terminator = block.getTerminator();
154-
if (isa<scf::YieldOp>(terminator))
155-
values.push_back(terminator->getOperands()[resultIdx]);
156-
}
157-
values.push_back(ifOp->getResults()[resultIdx]);
158-
return values;
159-
}
160-
return {};
161-
}
162-
163131
const EncodingInfo *internEncoding(std::unordered_set<EncodingInfo> &encodings,
164132
EncodingInfo info) {
165133
return &*encodings.insert(info).first;

0 commit comments

Comments
 (0)