Skip to content

Commit 4b20620

Browse files
[NFI] Clean up useTileLoadLinearLayout (#5232)
This PR removes the `useTileLoadLinearLayout` configuration option and its related infrastructure from the Intel Triton GPU backend. The change simplifies the codebase by removing a conditional feature flag that was previously enabled by default.

Signed-off-by: Whitney Tsang <[email protected]>
1 parent 1a1d82c commit 4b20620

File tree

8 files changed

+24
-68
lines changed

8 files changed

+24
-68
lines changed

python/triton/knobs.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -549,7 +549,6 @@ class intel_knobs(base_knobs):
549549

550550
dump_shader_info: env_bool = env_bool("TRITON_INTEL_ENABLE_IGC_SHADER_DUMP", False)
551551
gen_native_code: env_bool = env_bool("TRITON_XPU_GEN_NATIVE_CODE", False)
552-
tile_load_ll: env_bool = env_bool("TRITON_XPU_ENABLE_TILE_LOAD_LINEAR_LAYOUT", True)
553552
opt_reduction_locality: env_bool = env_bool("TRITON_INTEL_OPTIMIZE_REDUCTION_LOCALITY", False)
554553
disable_igc_opt: env_bool = env_bool("TRITON_INTEL_DISABLE_IGC_OPT", False)
555554

third_party/intel/backend/compiler.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,6 @@ class XPUOptions:
4040
backend_name: str = 'intel'
4141
sanitize_overflow: bool = False
4242
generate_native_code: bool = False
43-
enable_tile_load_linear_layout: bool = True
4443
arch: str = None
4544
# FIXME: enable for XPU: https://github.com/intel/intel-xpu-backend-for-triton/issues/4954
4645
instrumentation_mode: str = ""
@@ -133,7 +132,6 @@ def parse_target(self, tgt_prop) -> dict:
133132
def parse_options(self, opts) -> Any:
134133
args = {k: opts[k] for k in XPUOptions.__dataclass_fields__.keys() if k in opts}
135134
args["allow_fp8e4nv"] = True
136-
args["enable_tile_load_linear_layout"] = knobs.intel.tile_load_ll
137135
return XPUOptions(**args)
138136

139137
def pack_metadata(self, metadata):
@@ -298,7 +296,7 @@ def make_llir(src, metadata, options):
298296
# instrumentation point here so we can override IRs above (e.g., ttir and ttgir)
299297
if XPUBackend.instrumentation:
300298
XPUBackend.instrumentation.patch("ttgpuir_to_llvmir", pm, mod.context)
301-
intel.passes.ttgpuir.add_to_llvmir(pm, options.enable_tile_load_linear_layout)
299+
intel.passes.ttgpuir.add_to_llvmir(pm)
302300
intel.passes.ttgpuir.add_gen_to_llvm(pm)
303301
passes.common.add_canonicalizer(pm)
304302
intel.passes.ttgpuir.add_rewrite_stack_ptr(pm)

third_party/intel/include/TritonIntelGPUToLLVM/Passes.td

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,11 +21,6 @@ def ConvertTritonIntelGPUToLLVM
2121
"mlir::triton::TritonDialect",
2222
"mlir::triton::gpu::TritonGPUDialect",
2323
"mlir::triton::TritonGEN::TritonGENDialect"];
24-
let options = [
25-
Option<"useTileLoadLinearLayout", "use_tile_load_linear_layout",
26-
"bool", /*default*/"true",
27-
"Use linear layouts to generate the tile load sizes and offsets">
28-
];
2924
}
3025

3126
#endif // TRITONINTELGPU_CONVERSION_PASSES

third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp

Lines changed: 15 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -1012,10 +1012,9 @@ struct LoadOpToBlockIOConversion
10121012
LoadOpToBlockIOConversion(
10131013
LLVMTypeConverter &converter, const triton::intel::TargetInfo &targetInfo,
10141014
const triton::intel::ModuleAxisInfoAnalysis &axisAnalysisPass,
1015-
PatternBenefit benefit, bool useTileLoadLinearLayout)
1015+
PatternBenefit benefit)
10161016
: ConvertTritonGPUOpToLLVMPattern<triton::LoadOp>(converter, benefit),
1017-
BlockIOConversionBase(targetInfo, axisAnalysisPass),
1018-
useTileLoadLinearLayout(useTileLoadLinearLayout) {}
1017+
BlockIOConversionBase(targetInfo, axisAnalysisPass) {}
10191018

10201019
LogicalResult
10211020
rewriteTensorPointerLoad(triton::LoadOp op, OpAdaptor adaptor,
@@ -1540,18 +1539,11 @@ struct LoadOpToBlockIOConversion
15401539
// Disable building the load layout if we are not going to use it. Building
15411540
// the layout manually can cause an error which would abort the pass
15421541
// pipeline and block us from getting debug info.
1543-
if (useTileLoadLinearLayout) {
1544-
// add the bases to the map and replace the tile layout with the new
1545-
// layout
1546-
bases[kLoad] = newLoadBases;
1547-
tileLayout = LinearLayout(bases, outDims,
1548-
/*requiredSurjective=*/false);
1549-
} else {
1550-
// when linear layouts are disabled generate a single load, so we can have
1551-
// some reference for linear layout output without generating a layout
1552-
// that could abort the pass pipeline
1553-
tileLayout *= LinearLayout::identity1D(1, kLoad, dimOuterStr);
1554-
}
1542+
// add the bases to the map and replace the tile layout with the new
1543+
// layout
1544+
bases[kLoad] = newLoadBases;
1545+
tileLayout = LinearLayout(bases, outDims,
1546+
/*requiredSurjective=*/false);
15551547

15561548
LLVM_DEBUG({
15571549
llvm::dbgs() << "Block load tile layout after adding loads: "
@@ -1657,33 +1649,19 @@ struct LoadOpToBlockIOConversion
16571649
llvm::dbgs() << "y offset: "
16581650
<< outer * repOuterStride + rep * repStride << "\n";
16591651
});
1660-
if (useTileLoadLinearLayout) {
1661-
offsetY = b.add(b.mul(outerDimWarpId, b.i32_val(warpOuterStride)),
1662-
b.i32_val(layoutOffsetY));
1663-
offsetX = b.i32_val(layoutOffsetX);
1664-
} else {
1665-
offsetY =
1666-
b.add(b.mul(outerDimWarpId, b.i32_val(warpOuterStride)),
1667-
b.i32_val(outer * repOuterStride + rep * repStride));
1668-
offsetX = b.i32_val(k * repKStride);
1669-
}
1652+
offsetY = b.add(b.mul(outerDimWarpId, b.i32_val(warpOuterStride)),
1653+
b.i32_val(layoutOffsetY));
1654+
offsetX = b.i32_val(layoutOffsetX);
16701655
} break;
16711656
case DpasEncodingAttr::OpIdx::OperandB: {
16721657
LLVM_DEBUG({
16731658
llvm::dbgs() << "x offset: "
16741659
<< outer * repOuterStride + rep * repStride << "\n";
16751660
llvm::dbgs() << "y offset: " << k * repKStride << "\n";
16761661
});
1677-
if (useTileLoadLinearLayout) {
1678-
offsetX = b.add(b.mul(outerDimWarpId, b.i32_val(warpOuterStride)),
1679-
b.i32_val(layoutOffsetX));
1680-
offsetY = b.i32_val(layoutOffsetY);
1681-
} else {
1682-
offsetX =
1683-
b.add(b.mul(outerDimWarpId, b.i32_val(warpOuterStride)),
1684-
b.i32_val(outer * repOuterStride + rep * repStride));
1685-
offsetY = b.i32_val(k * repKStride);
1686-
}
1662+
offsetX = b.add(b.mul(outerDimWarpId, b.i32_val(warpOuterStride)),
1663+
b.i32_val(layoutOffsetX));
1664+
offsetY = b.i32_val(layoutOffsetY);
16871665
} break;
16881666
case DpasEncodingAttr::OpIdx::OperandC: {
16891667
llvm_unreachable("unexpected OpIdx::OperandC");
@@ -3164,9 +3142,6 @@ struct LoadOpToBlockIOConversion
31643142

31653143
return success();
31663144
}
3167-
3168-
private:
3169-
bool useTileLoadLinearLayout;
31703145
};
31713146

31723147
struct LoadOpConversion : public ConvertOpToLLVMPattern<triton::LoadOp>,
@@ -4258,14 +4233,11 @@ void mlir::triton::intel::populateLoadStoreOpToLLVMPatterns(
42584233
LLVMTypeConverter &typeConverter, const TargetInfo &targetInfo,
42594234
RewritePatternSet &patterns,
42604235
const intel::ModuleAxisInfoAnalysis &axisInfoAnalysis,
4261-
PatternBenefit benefit, bool useTileLoadLinearLayout) {
4236+
PatternBenefit benefit) {
42624237
patterns.add<AtomicCASOpConversion, AtomicRMWOpConversion, LoadOpConversion,
42634238
StoreOpConversion, PrefetchOpConversion>(
42644239
typeConverter, targetInfo, axisInfoAnalysis, benefit);
42654240
// BlockIO is more efficient than gather load or scatter store.
4266-
patterns.add<LoadOpToBlockIOConversion>(
4267-
typeConverter, targetInfo, axisInfoAnalysis, benefit.getBenefit() + 2,
4268-
useTileLoadLinearLayout);
4269-
patterns.add<StoreOpToBlockIOConversion>(
4241+
patterns.add<LoadOpToBlockIOConversion, StoreOpToBlockIOConversion>(
42704242
typeConverter, targetInfo, axisInfoAnalysis, benefit.getBenefit() + 2);
42714243
}

third_party/intel/lib/TritonIntelGPUToLLVM/PatternTritonGPUOpToLLVM.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ void populateFp4ToFpToLLVMPatterns(LLVMTypeConverter &typeConverter,
6161
void populateLoadStoreOpToLLVMPatterns(
6262
LLVMTypeConverter &typeConverter, const TargetInfo &targetInfo,
6363
RewritePatternSet &patterns, const ModuleAxisInfoAnalysis &axisInfoAnalysis,
64-
PatternBenefit benefit, bool useTileLoadLinearLayout);
64+
PatternBenefit benefit);
6565

6666
void populateTensorPtrOpsToLLVMPatterns(LLVMTypeConverter &typeConverter,
6767
RewritePatternSet &patterns,

third_party/intel/lib/TritonIntelGPUToLLVM/PipelineManager.h

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -184,9 +184,8 @@ struct FuncOpConversion : public ConvertOpToLLVMPattern<triton::FuncOp> {
184184
/// block pointers or not.
185185
class TritonGPUToLLVMPipelineManager {
186186
public:
187-
TritonGPUToLLVMPipelineManager(ModuleOp &mod, MLIRContext *ctx,
188-
bool useTileLoadLinearLayout)
189-
: mod(mod), ctx(ctx), useTileLoadLinearLayout(useTileLoadLinearLayout) {}
187+
TritonGPUToLLVMPipelineManager(ModuleOp &mod, MLIRContext *ctx)
188+
: mod(mod), ctx(ctx) {}
190189

191190
/// Populate the conversion pipeline for function operations.
192191
void populateFunctionConversionPatterns(
@@ -213,9 +212,8 @@ class TritonGPUToLLVMPipelineManager {
213212
intel::populateDotOpToLLVMPatterns(typeConverter, patterns, benefit);
214213
intel::populateElementwiseOpToLLVMPatterns(
215214
typeConverter, patterns, axisInfoAnalysis, targetInfo, benefit);
216-
intel::populateLoadStoreOpToLLVMPatterns(typeConverter, targetInfo,
217-
patterns, axisInfoAnalysis,
218-
benefit, useTileLoadLinearLayout);
215+
intel::populateLoadStoreOpToLLVMPatterns(
216+
typeConverter, targetInfo, patterns, axisInfoAnalysis, benefit);
219217
intel::populateReduceOpToLLVMPatterns(typeConverter, patterns, targetInfo,
220218
benefit);
221219
mlir::triton::populateScanOpToLLVMPatterns(typeConverter, patterns,
@@ -259,8 +257,6 @@ class TritonGPUToLLVMPipelineManager {
259257
private:
260258
ModuleOp &mod;
261259
MLIRContext *ctx;
262-
263-
bool useTileLoadLinearLayout = true;
264260
};
265261

266262
} // namespace mlir::triton::intel

third_party/intel/lib/TritonIntelGPUToLLVM/TritonGPUToLLVM.cpp

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -60,10 +60,6 @@ struct ConvertTritonGPUToLLVM
6060
: public triton::gpu::intel::impl::ConvertTritonIntelGPUToLLVMBase<
6161
ConvertTritonGPUToLLVM> {
6262
using ConvertTritonIntelGPUToLLVMBase::ConvertTritonIntelGPUToLLVMBase;
63-
ConvertTritonGPUToLLVM() = default;
64-
ConvertTritonGPUToLLVM(bool useTileLoadLinearLayout) {
65-
this->useTileLoadLinearLayout = useTileLoadLinearLayout;
66-
}
6763

6864
void getDependentDialects(DialectRegistry &registry) const override {
6965
registry.insert<LLVM::LLVMDialect, TritonGEN::TritonGENDialect,
@@ -75,7 +71,7 @@ struct ConvertTritonGPUToLLVM
7571
ModuleOp mod = getOperation();
7672

7773
mlir::triton::intel::TritonGPUToLLVMPipelineManager pipelineManager(
78-
mod, context, useTileLoadLinearLayout);
74+
mod, context);
7975
mlir::LowerToLLVMOptions option(context);
8076
auto targetInfo = mlir::triton::intel::createTargetInfo(mod);
8177
TritonIntelGPUToLLVMTypeConverter typeConverter(context, option,

third_party/intel/triton_xpu.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,8 +60,8 @@ void init_triton_intel_passes_ttir(py::module &&m) {
6060
}
6161

6262
void init_triton_intel_passes_ttgpuir(py::module &&m) {
63-
ADD_PASS_OPTION_WRAPPER_1(
64-
"add_to_llvmir", gpu::intel::createConvertTritonIntelGPUToLLVM, bool);
63+
ADD_PASS_WRAPPER_0("add_to_llvmir",
64+
gpu::intel::createConvertTritonIntelGPUToLLVM);
6565
ADD_PASS_WRAPPER_0("add_gen_to_llvm", createConvertTritonGENToLLVM);
6666
ADD_PASS_WRAPPER_0("add_accelerate_matmul",
6767
gpu::intel::createTritonIntelGPUAccelerateMatmul);

0 commit comments

Comments (0)