Skip to content

Commit c2a39f4

Browse files
authored
[RemoveLayoutConversions]: Protect for loop support with env. variable (#5282)
Signed-off-by: Ettore Tiotto <[email protected]>
1 parent 0ed53f4 commit c2a39f4

File tree

4 files changed

+50
-28
lines changed

4 files changed

+50
-28
lines changed

include/triton/Tools/Sys/GetEnv.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ inline const std::set<std::string> CACHE_INVALIDATING_ENV_VARS = {
5050
"TRITON_INTEL_FAST_MATH",
5151
"TRITON_INTEL_ONE_MATRIX_PER_LOAD_BT",
5252
"TRITON_INTEL_PREDICATED",
53+
"TRITON_INTEL_REMOVELAYOUTCONVERSION_SUPPORT_FOR_LOOP",
5354
// clang-format on
5455
};
5556

test/TritonIntelGPU/combine.mlir

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
// RUN: triton-opt %s -split-input-file -allow-unregistered-dialect -tritonintelgpu-remove-layout-conversions 2>&1 | FileCheck %s
1+
// RUN: env TRITON_INTEL_REMOVELAYOUTCONVERSION_SUPPORT_FOR_LOOP=0 triton-opt %s -split-input-file -allow-unregistered-dialect -tritonintelgpu-remove-layout-conversions 2>&1 | FileCheck --check-prefixes=CHECK %s
2+
// RUN: env TRITON_INTEL_REMOVELAYOUTCONVERSION_SUPPORT_FOR_LOOP=1 triton-opt %s -split-input-file -allow-unregistered-dialect -tritonintelgpu-remove-layout-conversions 2>&1 | FileCheck --check-prefixes=CHECK,FOR-SUPPORT %s
23

34
#layout0 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
45
#layout1 = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
@@ -351,7 +352,7 @@ tt.func @loop(%arg0: !tt.ptr<f32>, %arg1: i32, %arg2: !tt.ptr<f32>, %arg3: i32,
351352
// CHECK: scf.for
352353
// CHECK-NOT: ttg.convert_layout
353354
// CHECK: scf.if
354-
// CHECK: ttg.convert_layout
355+
// FOR-SUPPORT: ttg.convert_layout
355356
// CHECK: scf.yield
356357
// CHECK-NEXT: else
357358
// CHECK-NEXT: scf.yield

third_party/intel/lib/TritonIntelGPUTransforms/RemoveLayoutConversions.cpp

Lines changed: 36 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
2626
#include "triton/Dialect/TritonGPU/Transforms/TritonGPUConversion.h"
2727
#include "triton/Dialect/TritonGPU/Transforms/Utility.h"
28+
#include "triton/Tools/Sys/GetEnv.hpp"
2829
#include "llvm/ADT/TypeSwitch.h"
2930
#include <deque>
3031

@@ -1101,6 +1102,10 @@ void LayoutRematerialization::rewriteSlice(SetVector<Value> &slice,
11011102
DenseMap<Value, Attribute> &layout,
11021103
ConvertLayoutOp convertOp,
11031104
IRMapping &mapping) {
1105+
std::optional<bool> enableForLoopSupport =
1106+
mlir::triton::tools::isEnvValueBool(mlir::triton::tools::getStrEnv(
1107+
"TRITON_INTEL_REMOVELAYOUTCONVERSION_SUPPORT_FOR_LOOP"));
1108+
11041109
SetVector<Operation *> opsToRewrite;
11051110
// Keep track of yield operands that need to be duplicated.
11061111
DenseMap<Operation *, SmallVector<int>> yieldOperandsMap;
@@ -1126,12 +1131,13 @@ void LayoutRematerialization::rewriteSlice(SetVector<Value> &slice,
11261131
opsToRewrite.insert(ifOp.elseYield().getOperation());
11271132
yieldOperandsMap[ifOp.elseYield()].push_back(operandIdx);
11281133
}
1129-
if (auto forOp = v.getDefiningOp<scf::ForOp>()) {
1130-
unsigned operandIdx = cast<OpResult>(v).getResultNumber();
1131-
auto yieldOp = forOp.getBody()->getTerminator();
1132-
yieldOperandsMap[yieldOp].push_back(operandIdx);
1133-
opsToRewrite.insert(yieldOp);
1134-
}
1134+
if (enableForLoopSupport)
1135+
if (auto forOp = v.getDefiningOp<scf::ForOp>()) {
1136+
unsigned operandIdx = cast<OpResult>(v).getResultNumber();
1137+
auto yieldOp = forOp.getBody()->getTerminator();
1138+
yieldOperandsMap[yieldOp].push_back(operandIdx);
1139+
opsToRewrite.insert(yieldOp);
1140+
}
11351141
} else {
11361142
BlockArgument blockArg = cast<BlockArgument>(v);
11371143
Operation *parentOp = blockArg.getOwner()->getParentOp();
@@ -1155,17 +1161,19 @@ void LayoutRematerialization::rewriteSlice(SetVector<Value> &slice,
11551161
IRRewriter builder(slice.begin()->getContext());
11561162
for (Operation *op : opsToRewrite) {
11571163
if (auto forOp = dyn_cast<scf::ForOp>(op)) {
1158-
// Construct the new initialization argument by adding yielded operands
1159-
// that have been remapped.
1160-
auto yieldOp = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
1161-
auto yieldOperands = llvm::to_vector(yieldOp.getOperands());
1162-
SmallVector<int> operandsToRewrite = yieldOperandsMap[yieldOp];
1163-
std::sort(operandsToRewrite.begin(), operandsToRewrite.end());
11641164
SmallVector<Value> newOperands;
1165-
for (int operandIdx : operandsToRewrite) {
1166-
Value yieldOperand = yieldOp.getOperand(operandIdx);
1167-
if (mapping.contains(yieldOperand))
1168-
newOperands.push_back(mapping.lookup(yieldOperand));
1165+
if (enableForLoopSupport) {
1166+
// Construct the new initialization argument by adding yielded operands
1167+
// that have been remapped.
1168+
auto yieldOp = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
1169+
auto yieldOperands = llvm::to_vector(yieldOp.getOperands());
1170+
SmallVector<int> operandsToRewrite = yieldOperandsMap[yieldOp];
1171+
std::sort(operandsToRewrite.begin(), operandsToRewrite.end());
1172+
for (int operandIdx : operandsToRewrite) {
1173+
Value yieldOperand = yieldOp.getOperand(operandIdx);
1174+
if (mapping.contains(yieldOperand))
1175+
newOperands.push_back(mapping.lookup(yieldOperand));
1176+
}
11691177
}
11701178

11711179
// Keep a mapping of the operands index to the new operands index.
@@ -1183,17 +1191,19 @@ void LayoutRematerialization::rewriteSlice(SetVector<Value> &slice,
11831191
scf::ForOp newForOp = replaceForOpWithNewSignature(
11841192
builder, forOp, newOperands, replacements);
11851193

1186-
// Add rematerializations for loop results in the slice.
1187-
unsigned oldIdx = 0;
1188-
unsigned newIdx = forOp.getNumResults();
1189-
for (auto res : forOp.getResults()) {
1190-
if (slice.count(res)) {
1191-
mapping.map(forOp.getResult(oldIdx), newForOp.getResult(newIdx));
1192-
addRematValue(forOp.getResult(oldIdx), layout[res],
1193-
newForOp.getResult(newIdx));
1194-
++newIdx;
1194+
if (enableForLoopSupport) {
1195+
// Add rematerializations for loop results in the slice.
1196+
unsigned oldIdx = 0;
1197+
unsigned newIdx = forOp.getNumResults();
1198+
for (auto res : forOp.getResults()) {
1199+
if (slice.count(res)) {
1200+
mapping.map(forOp.getResult(oldIdx), newForOp.getResult(newIdx));
1201+
addRematValue(forOp.getResult(oldIdx), layout[res],
1202+
newForOp.getResult(newIdx));
1203+
++newIdx;
1204+
}
1205+
++oldIdx;
11951206
}
1196-
++oldIdx;
11971207
}
11981208

11991209
deadOps.push_back(forOp.getOperation());

third_party/intel/lib/TritonIntelGPUTransforms/Utility.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#include "intel/include/Dialect/TritonIntelGPU/Transforms/Utility.h"
1616
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
1717
#include "triton/Dialect/TritonGPU/Transforms/Utility.h"
18+
#include "triton/Tools/Sys/GetEnv.hpp"
1819

1920
#include <optional>
2021

@@ -186,6 +187,10 @@ LogicalResult getConvertBackwardSlice(
186187
DenseSet<std::pair<OpOperand *, Attribute>> seen;
187188
SmallVector<std::pair<OpOperand *, Attribute>> queue;
188189

190+
std::optional<bool> enableForLoopSupport =
191+
mlir::triton::tools::isEnvValueBool(mlir::triton::tools::getStrEnv(
192+
"TRITON_INTEL_REMOVELAYOUTCONVERSION_SUPPORT_FOR_LOOP"));
193+
189194
auto enqueue = [&](OpOperand &operand, Attribute encoding) {
190195
auto x = std::make_pair(&operand, encoding);
191196
if (!seen.insert(x).second) {
@@ -217,6 +222,11 @@ LogicalResult getConvertBackwardSlice(
217222
if (!isTensorOrTensorPointerType(currentValue.getType()))
218223
continue;
219224

225+
// Skip propagating through for op results for now.
226+
// TODO: enable this based on needs.
227+
if (!enableForLoopSupport && currentValue.getDefiningOp<scf::ForOp>())
228+
return failure();
229+
220230
if (failed(updateLayout(currentValue, encoding)))
221231
return failure();
222232

0 commit comments

Comments
 (0)