Skip to content

Commit b80f5dd

Browse files
authored
[Gluon] Add AutoLayout for backward layout inference (#7447)
This adds `ttgl.AutoLayout` which lowers to the new `gluon::AutoEncodingAttr` encoding. In the language, we allow this encoding to propagate through forward type inference using a custom layout inference interface which always returns auto regardless of the operation. So e.g. `sum(a, 1)` where `a.layout` is auto layout will return auto layout, not `SliceLayout(1, AutoLayout())`. This also adds the gluon-specific `--gluon-resolve-auto-encodings` pass, which finds layout conversion from auto encoding to a concrete encoding and propagates the concrete encoding through the graph. The pass will error out if there are any conflicting inferences, rather than duplicating computations as the backward rematerialization pass does. It also errors out if the inference fails for any reason. Note that the errors are really bad at the moment. Will need to fix in a follow-up.
1 parent 586d998 commit b80f5dd

File tree

28 files changed

+671
-26
lines changed

28 files changed

+671
-26
lines changed

bin/RegisterTritonDialects.h

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
#include "third_party/nvidia/include/Dialect/NVGPU/IR/Dialect.h"
55
#include "third_party/nvidia/include/Dialect/NVWS/IR/Dialect.h"
66
#include "third_party/proton/dialect/include/Dialect/Proton/IR/Dialect.h"
7+
#include "triton/Dialect/Gluon/Transforms/Passes.h"
78
#include "triton/Dialect/Triton/IR/Dialect.h"
89
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
910
#include "triton/Dialect/TritonInstrument/IR/Dialect.h"
@@ -50,6 +51,7 @@ inline void registerTritonDialects(mlir::DialectRegistry &registry) {
5051
mlir::triton::gpu::registerTritonGPUPasses();
5152
mlir::triton::nvidia_gpu::registerTritonNvidiaGPUPasses();
5253
mlir::triton::instrument::registerTritonInstrumentPasses();
54+
mlir::triton::gluon::registerGluonPasses();
5355
mlir::test::registerTestAliasPass();
5456
mlir::test::registerTestAlignmentPass();
5557
mlir::test::registerAMDTestAlignmentPass();
@@ -105,5 +107,6 @@ inline void registerTritonDialects(mlir::DialectRegistry &registry) {
105107
mlir::gpu::GPUDialect, mlir::LLVM::LLVMDialect, mlir::NVVM::NVVMDialect,
106108
mlir::triton::nvgpu::NVGPUDialect, mlir::triton::nvws::NVWSDialect,
107109
mlir::triton::amdgpu::TritonAMDGPUDialect,
108-
mlir::triton::proton::ProtonDialect, mlir::ROCDL::ROCDLDialect>();
110+
mlir::triton::proton::ProtonDialect, mlir::ROCDL::ROCDLDialect,
111+
mlir::triton::gluon::GluonDialect>();
109112
}

include/triton/Dialect/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,3 +2,4 @@ add_subdirectory(Triton)
22
add_subdirectory(TritonGPU)
33
add_subdirectory(TritonNvidiaGPU)
44
add_subdirectory(TritonInstrument)
5+
add_subdirectory(Gluon)
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
add_subdirectory(IR)
2+
add_subdirectory(Transforms)
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
add_subdirectory(IR)
2+
add_subdirectory(Transforms)
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
set(MLIR_BINARY_DIR ${CMAKE_BINARY_DIR})
2+
3+
set(LLVM_TARGET_DEFINITIONS GluonDialect.td)
4+
mlir_tablegen(Dialect.h.inc -gen-dialect-decls -dialect=gluon)
5+
mlir_tablegen(Dialect.cpp.inc -gen-dialect-defs -dialect=gluon)
6+
mlir_tablegen(Ops.h.inc -gen-op-decls)
7+
mlir_tablegen(Ops.cpp.inc -gen-op-defs)
8+
add_mlir_doc(GluonDialect GluonDialect dialects/ -gen-dialect-doc)
9+
add_public_tablegen_target(GluonTableGen)
10+
11+
set(LLVM_TARGET_DEFINITIONS GluonAttrDefs.td)
12+
mlir_tablegen(GluonAttrDefs.h.inc -gen-attrdef-decls)
13+
mlir_tablegen(GluonAttrDefs.cpp.inc -gen-attrdef-defs)
14+
mlir_tablegen(OpsEnums.h.inc -gen-enum-decls)
15+
mlir_tablegen(OpsEnums.cpp.inc -gen-enum-defs)
16+
add_public_tablegen_target(GluonAttrDefsIncGen)
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
#pragma once
2+
#include "triton/Dialect/Triton/IR/Dialect.h"
3+
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
4+
5+
#include "triton/Dialect/Gluon/IR/Dialect.h.inc"
6+
7+
#define GET_ATTRDEF_CLASSES
8+
#include "triton/Dialect/Gluon/IR/GluonAttrDefs.h.inc"
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
#ifndef GLUON_ATTRDEFS
2+
#define GLUON_ATTRDEFS
3+
4+
include "mlir/IR/AttrTypeBase.td"
5+
include "triton/Dialect/Gluon/IR/GluonDialect.td"
6+
7+
def Gluon_AutoEncodingAttr : AttrDef<Gluon_Dialect, "AutoEncoding"> {
8+
let mnemonic = "auto_encoding";
9+
let attrName = "gluon.auto_encoding";
10+
let description = [{
11+
An encoding that is inferred from neighboring ops in the graph.
12+
}];
13+
}
14+
15+
#endif
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
#ifndef GLUON_DIALECT
2+
#define GLUON_DIALECT
3+
4+
include "mlir/IR/OpBase.td"
5+
6+
def Gluon_Dialect : Dialect {
7+
let name = "gluon";
8+
let cppNamespace = "::mlir::triton::gluon";
9+
let description = [{
10+
Gluon dialect.
11+
}];
12+
13+
let dependentDialects = [
14+
"triton::TritonDialect",
15+
"triton::gpu::TritonGPUDialect",
16+
"mlir::gpu::GPUDialect",
17+
];
18+
let useDefaultAttributePrinterParser = 1;
19+
let usePropertiesForAttributes = 1;
20+
}
21+
22+
#endif
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
set(LLVM_TARGET_DEFINITIONS Passes.td)
2+
mlir_tablegen(Passes.h.inc -gen-pass-decls -name Gluon)
3+
add_public_tablegen_target(GluonTransformsIncGen)
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
#pragma once
2+
#include "mlir/IR/BuiltinOps.h"
3+
#include "mlir/Pass/Pass.h"
4+
#include "triton/Dialect/Gluon/IR/Dialect.h"
5+
#include <memory>
6+
7+
namespace mlir::triton::gluon {
8+
9+
#define GEN_PASS_DECL
10+
#define GEN_PASS_REGISTRATION
11+
#include "triton/Dialect/Gluon/Transforms/Passes.h.inc"
12+
13+
} // namespace mlir::triton::gluon

0 commit comments

Comments
 (0)