
Commit 47d57c1

Pangoraw, giordano, and wsmoses authored
Triton forward mode AD (#1578)
* Triton forward mode AD
* fmt build
* header
* Bump enzyme
* Update TritonDerivatives.td

Co-authored-by: Mosè Giordano <[email protected]>
Co-authored-by: William Moses <[email protected]>
1 parent 3bb71d5 commit 47d57c1


5 files changed: +207 -0 lines changed


src/enzyme_ad/jax/BUILD

Lines changed: 15 additions & 0 deletions
@@ -266,6 +266,20 @@ td_library(
     ],
 )
 
+gentbl_cc_library(
+    name = "triton-derivatives",
+    tbl_outs = [(
+        ["-gen-mlir-derivatives"],
+        "Implementations/TritonDerivatives.inc",
+    )],
+    tblgen = "@enzyme//:enzyme-tblgen",
+    td_file = "Implementations/TritonDerivatives.td",
+    td_srcs = [
+        "Implementations/TritonDerivatives.td",
+    ],
+    deps = [":ImplementationsCommonTdFiles"],
+)
+
 gentbl_cc_library(
     name = "mhlo-derivatives",
     tbl_outs = [(
@@ -874,6 +888,7 @@ cc_library(
         ":enzymexla-derivatives",
         ":mhlo-derivatives",
         ":stablehlo-derivatives",
+        ":triton-derivatives",
         "//src/external/isl:Isl",
         "@com_google_absl//absl/status",
         "@com_google_absl//absl/status:statusor",

src/enzyme_ad/jax/Implementations/TritonAutoDiffOpInterfaceImpl.cpp

Lines changed: 96 additions & 0 deletions
@@ -0,0 +1,96 @@
//===- TritonAutoDiffOpInterfaceImpl.cpp - Interface external model -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file contains the external model implementation of the automatic
// differentiation op interfaces for the MLIR tt dialect.
//
//===----------------------------------------------------------------------===//

#include "Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.h"
#include "Enzyme/MLIR/Interfaces/AutoDiffOpInterface.h"
#include "Enzyme/MLIR/Interfaces/GradientUtils.h"
#include "Enzyme/MLIR/Interfaces/GradientUtilsReverse.h"

#include "triton/Dialect/Triton/IR/Dialect.h"

#include "src/enzyme_ad/jax/Implementations/XLADerivatives.h"

using namespace mlir;
using namespace mlir::enzyme;
using namespace mlir::triton;

namespace {

#include "src/enzyme_ad/jax/Implementations/TritonDerivatives.inc"

class TritonPointerTypeInterface
    : public AutoDiffTypeInterface::ExternalModel<TritonPointerTypeInterface,
                                                  triton::PointerType> {
public:
  mlir::Value createNullValue(mlir::Type self, OpBuilder &builder,
                              Location loc) const {
    llvm_unreachable("TODO");
  }

  Value createAddOp(Type self, OpBuilder &builder, Location loc, Value a,
                    Value b) const {
    llvm_unreachable("TODO");
  }

  Value createConjOp(Type self, OpBuilder &builder, Location loc,
                     Value a) const {
    llvm_unreachable("TODO");
  }

  Type getShadowType(Type self, unsigned width) const {
    assert(width == 1 && "unsupported width != 1");
    return self;
  }

  bool isMutable(Type self) const { return true; }

  LogicalResult zeroInPlace(Type self, OpBuilder &builder, Location loc,
                            Value val) const {
    // TODO inspect val and memset corresponding size
    return failure();
  }

  bool isZero(Type self, Value val) const { return false; }
  bool isZeroAttr(Type self, Attribute attr) const { return false; }
};

class AutoDiffTritonFuncFunctionInterface
    : public AutoDiffFunctionInterface::ExternalModel<
          AutoDiffTritonFuncFunctionInterface, triton::FuncOp> {
public:
  void transformResultTypes(Operation *self,
                            SmallVectorImpl<Type> &returnTypes) const {}

  Operation *createCall(Operation *self, OpBuilder &builder, Location loc,
                        ValueRange args) const {
    return triton::CallOp::create(builder, loc, cast<triton::FuncOp>(self),
                                  args);
  }

  Operation *createReturn(Operation *self, OpBuilder &builder, Location loc,
                          ValueRange retArgs) const {
    return triton::ReturnOp::create(builder, loc, retArgs);
  }
};

} // end anonymous namespace

void mlir::enzyme::registerTritonDialectAutoDiffInterface(
    DialectRegistry &registry) {
  registry.addExtension(+[](MLIRContext *context, triton::TritonDialect *) {
    registerInterfaces(context);
    triton::FuncOp::attachInterface<AutoDiffTritonFuncFunctionInterface>(
        *context);
    triton::PointerType::attachInterface<TritonPointerTypeInterface>(*context);
  });
}
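
For orientation, here is a minimal hand-written sketch (not part of this commit) of how the type interface attached above is consumed once registered: a !tt.ptr<f32> type can be cast to AutoDiffTypeInterface and queried. It reuses the same includes and namespaces as the implementation file above; the helper name inspectPointerType is hypothetical.

#include "Enzyme/MLIR/Interfaces/AutoDiffOpInterface.h"
#include "triton/Dialect/Triton/IR/Dialect.h"

using namespace mlir;
using namespace mlir::enzyme;

// Hypothetical helper: query the AutoDiff type interface on a !tt.ptr type
// after registerTritonDialectAutoDiffInterface has attached the external model.
static void inspectPointerType(Type ptrTy) {
  if (auto iface = dyn_cast<AutoDiffTypeInterface>(ptrTy)) {
    // Pointers shadow as themselves: getShadowType(1) returns the same type.
    Type shadowTy = iface.getShadowType(/*width=*/1);
    // tt.ptr is memory-backed, so the external model reports it as mutable.
    bool mut = iface.isMutable();
    (void)shadowTy;
    (void)mut;
  }
}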

src/enzyme_ad/jax/Implementations/TritonDerivatives.td

Lines changed: 41 additions & 0 deletions
@@ -0,0 +1,41 @@
include "src/enzyme_ad/jax/Implementations/Common.td"

class TritonDerivative<string opName_, dag patternToMatch, list<dag> resultOps, dag forwardOps=(ForwardFromSummedReverse)> : MLIRDerivative<"triton", opName_, patternToMatch, resultOps, forwardOps>;

class TritonInst<string m, string postopt="", string preopt=""> : Inst<m, "triton", postopt, preopt>;

class TritonMemoryIdentityOp<string opName_, list<int> ptrargs_, list<int> storedargs_ = [], dag patternToMatch=(Unimplemented), list<dag> reverse_ = []> : MemoryIdentityOp<"triton", opName_, ptrargs_, storedargs_, patternToMatch, reverse_>;

class TritonReadOnlyIdentityOp<string opName_, list<int> ptrargs_ = [0], dag patternToMatch=(Unimplemented), list<dag> reverse_ = []> : ReadOnlyIdentityOp<"triton", opName_, ptrargs_, patternToMatch, reverse_>;

class ArithConstantFP<string m> : ConstantFP<m, "arith", "ConstantOp", "mlir::ElementsAttr">;

class TritonInactiveOp<string m> : InactiveOp<"triton", m>;

class TritonReturnOp<string m> : ReturnOp<"triton", m>;

def FpToFp : TritonInst<"FpToFpOp">;
def PreciseDivF : TritonInst<"PreciseDivFOp">;
def MakeRange : TritonInst<"MakeRangeOp">;

def : TritonReturnOp<"ReturnOp">;

def : TritonInactiveOp<"AssertOp">;
def : TritonInactiveOp<"MakeRangeOp">;
def : TritonInactiveOp<"PrintOp">;

def : ReadOnlyIdentityOp<"triton", "AddPtrOp", [0]>;
def : ReadOnlyIdentityOp<"triton", "AdvanceOp", [0]>;
def : ReadOnlyIdentityOp<"triton", "LoadOp", [0]>;
def : ReadOnlyIdentityOp<"triton", "SplatOp", [0]>;
def : MemoryIdentityOp<"triton", "StoreOp", [1], [0]>;

def FpToFpRoundingMode : GlobalExpr</*needsprimal*/0, /*needsshadow*/0, [{
  op.getRoundingAttr();
}]>;

def : TritonDerivative<"FpToFpOp", (Op $x),
  [
    (FpToFp (TypeOf $x), (DiffeRet), (FpToFpRoundingMode))
  ]
>;

src/enzyme_ad/jax/Implementations/XLADerivatives.h

Lines changed: 2 additions & 0 deletions
@@ -14,13 +14,15 @@ void registerMHLODialectAutoDiffInterface(mlir::DialectRegistry &registry);
 void registerStableHLODialectAutoDiffInterface(mlir::DialectRegistry &registry);
 void registerCHLODialectAutoDiffInterface(mlir::DialectRegistry &registry);
 void registerEnzymeXLADialectAutoDiffInterface(mlir::DialectRegistry &registry);
+void registerTritonDialectAutoDiffInterface(mlir::DialectRegistry &registry);
 
 static inline void
 registerXLAAutoDiffInterfaces(mlir::DialectRegistry &registry) {
   registerMHLODialectAutoDiffInterface(registry);
   registerStableHLODialectAutoDiffInterface(registry);
   registerCHLODialectAutoDiffInterface(registry);
   registerEnzymeXLADialectAutoDiffInterface(registry);
+  registerTritonDialectAutoDiffInterface(registry);
 }
 } // namespace enzyme
 } // namespace mlir
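
A usage sketch (assumed driver code, not from this commit): a tool that wants these derivative interfaces only needs to register the Triton dialect and call registerXLAAutoDiffInterfaces, which after this change also attaches the Triton interfaces. The surrounding main and pass setup are placeholders.

#include "mlir/IR/DialectRegistry.h"
#include "mlir/IR/MLIRContext.h"
#include "src/enzyme_ad/jax/Implementations/XLADerivatives.h"
#include "triton/Dialect/Triton/IR/Dialect.h"

int main() {
  mlir::DialectRegistry registry;
  // Make the Triton dialect loadable in this context.
  registry.insert<mlir::triton::TritonDialect>();
  // Attaches the MHLO/StableHLO/CHLO/EnzymeXLA AutoDiff interfaces and, with
  // this commit, the Triton ones via registerTritonDialectAutoDiffInterface.
  mlir::enzyme::registerXLAAutoDiffInterfaces(registry);
  mlir::MLIRContext context(registry);
  // ... parse a module and run the Enzyme pipeline (e.g. enzyme-wrap) here ...
  return 0;
}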
Lines changed: 53 additions & 0 deletions
@@ -0,0 +1,53 @@
// RUN: enzymexlamlir-opt %s --enzyme-wrap="infn=add_kernel outfn= argTys=enzyme_dup,enzyme_const,enzyme_dup,enzyme_const retTys= mode=ForwardMode" | FileCheck %s

module {
  tt.func public @add_kernel(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} {
    %c1024_i32 = arith.constant 1024 : i32
    %0 = tt.get_program_id x : i32
    %1 = arith.muli %0, %c1024_i32 : i32
    %2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32>
    %3 = tt.splat %1 : i32 -> tensor<1024xi32>
    %4 = arith.addi %3, %2 : tensor<1024xi32>
    %5 = tt.splat %arg3 : i32 -> tensor<1024xi32>
    %6 = arith.cmpi slt, %4, %5 : tensor<1024xi32>
    %7 = tt.splat %arg0 : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>>
    %8 = tt.addptr %7, %4 : tensor<1024x!tt.ptr<f32>>, tensor<1024xi32>
    %9 = tt.load %8, %6 : tensor<1024x!tt.ptr<f32>>
    %10 = tt.splat %arg1 : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>>
    %11 = tt.addptr %10, %4 : tensor<1024x!tt.ptr<f32>>, tensor<1024xi32>
    %12 = tt.load %11, %6 : tensor<1024x!tt.ptr<f32>>
    %13 = arith.addf %9, %12 : tensor<1024xf32>
    %14 = tt.splat %arg2 : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>>
    %15 = tt.addptr %14, %4 : tensor<1024x!tt.ptr<f32>>, tensor<1024xi32>
    tt.store %15, %13, %6 : tensor<1024x!tt.ptr<f32>>
    tt.return
  }
}

// CHECK: tt.func @add_kernel(%[[arg0:.+]]: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %[[arg1:.+]]: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %[[arg2:.+]]: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %[[arg3:.+]]: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %[[arg4:.+]]: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %[[arg5:.+]]: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} {
// CHECK-NEXT: %[[c1024_i32:.+]] = arith.constant 1024 : i32
// CHECK-NEXT: %[[v0:.+]] = tt.get_program_id x : i32
// CHECK-NEXT: %[[v1:.+]] = arith.muli %[[v0]], %[[c1024_i32]] : i32
// CHECK-NEXT: %[[v2:.+]] = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32>
// CHECK-NEXT: %[[v3:.+]] = tt.splat %[[v1]] : i32 -> tensor<1024xi32>
// CHECK-NEXT: %[[v4:.+]] = arith.addi %[[v3]], %[[v2]] : tensor<1024xi32>
// CHECK-NEXT: %[[v5:.+]] = tt.splat %[[arg5]] : i32 -> tensor<1024xi32>
// CHECK-NEXT: %[[v6:.+]] = arith.cmpi slt, %[[v4]], %[[v5]] : tensor<1024xi32>
// CHECK-NEXT: %[[v7:.+]] = tt.splat %[[arg1]] : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>>
// CHECK-NEXT: %[[v8:.+]] = tt.splat %[[arg0]] : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>>
// CHECK-NEXT: %[[v9:.+]] = tt.addptr %[[v7]], %[[v4]] : tensor<1024x!tt.ptr<f32>>, tensor<1024xi32>
// CHECK-NEXT: %[[v10:.+]] = tt.addptr %[[v8]], %[[v4]] : tensor<1024x!tt.ptr<f32>>, tensor<1024xi32>
// CHECK-NEXT: %[[v11:.+]] = tt.load %[[v9]], %[[v6]] : tensor<1024x!tt.ptr<f32>>
// CHECK-NEXT: %[[v12:.+]] = tt.load %[[v10]], %[[v6]] : tensor<1024x!tt.ptr<f32>>
// CHECK-NEXT: %[[v13:.+]] = tt.splat %[[arg2]] : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>>
// CHECK-NEXT: %[[v14:.+]] = tt.addptr %[[v13]], %[[v4]] : tensor<1024x!tt.ptr<f32>>, tensor<1024xi32>
// CHECK-NEXT: %[[v15:.+]] = tt.load %[[v14]], %[[v6]] : tensor<1024x!tt.ptr<f32>>
// CHECK-NEXT: %[[v16:.+]] = arith.addf %[[v12]], %[[v15]] : tensor<1024xf32>
// CHECK-NEXT: %[[v17:.+]] = tt.splat %[[arg4]] : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>>
// CHECK-NEXT: %[[v18:.+]] = tt.splat %[[arg3]] : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>>
// CHECK-NEXT: %[[v19:.+]] = tt.addptr %[[v17]], %[[v4]] : tensor<1024x!tt.ptr<f32>>, tensor<1024xi32>
// CHECK-NEXT: %[[v20:.+]] = tt.addptr %[[v18]], %[[v4]] : tensor<1024x!tt.ptr<f32>>, tensor<1024xi32>
// CHECK-NEXT: tt.store %[[v19]], %[[v11]], %[[v6]] : tensor<1024x!tt.ptr<f32>>
// CHECK-NEXT: tt.store %[[v20]], %[[v16]], %[[v6]] : tensor<1024x!tt.ptr<f32>>
// CHECK-NEXT: tt.return
// CHECK-NEXT: }
