diff --git a/src/enzyme_ad/jax/Implementations/TritonExtAutoDiffOpInterfaceImpl.cpp b/src/enzyme_ad/jax/Implementations/TritonExtAutoDiffOpInterfaceImpl.cpp
new file mode 100644
index 000000000..c6abe8959
--- /dev/null
+++ b/src/enzyme_ad/jax/Implementations/TritonExtAutoDiffOpInterfaceImpl.cpp
@@ -0,0 +1,278 @@
+//===- TritonExtAutoDiffOpInterfaceImpl.cpp - Interface external model ----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains the external model implementation of the automatic
+// differentiation op interfaces for the MLIR triton_ext dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#include "Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.h"
+#include "Enzyme/MLIR/Interfaces/AutoDiffOpInterface.h"
+#include "Enzyme/MLIR/Interfaces/GradientUtils.h"
+
+#include "src/enzyme_ad/jax/Dialect/TritonExt/Dialect.h"
+#include "src/enzyme_ad/jax/Implementations/XLADerivatives.h"
+
+using namespace mlir;
+using namespace mlir::enzyme;
+using namespace mlir::enzymexla;
+
+namespace {
+
+// This assumes no tuples in either the arguments or the results.
+static std::optional<int64_t>
+findAliasedOperand(ArrayAttr outputOperandAliases, unsigned outputIndex) {
+  for (auto attr : outputOperandAliases) {
+    auto alias = cast<stablehlo::OutputOperandAliasAttr>(attr);
+    if (alias.getOutputTupleIndices()[0] != outputIndex)
+      continue;
+    assert(alias.getOutputTupleIndices().size() == 1);
+    assert(alias.getOperandTupleIndices().empty());
+    return alias.getOperandIndex();
+  }
+  return std::nullopt;
+}
+
+class AutoDiffTritonCallFwd
+    : public AutoDiffOpInterface::ExternalModel<AutoDiffTritonCallFwd,
+                                                triton_ext::TritonCallOp> {
+public:
+  LogicalResult createForwardModeTangent(Operation *orig, OpBuilder &builder,
+                                         MGradientUtils *gutils) const {
+    DerivativeMode mode = DerivativeMode::ForwardMode;
+
+    auto callOp = cast<triton_ext::TritonCallOp>(orig);
+
+    for (auto [i, arg] : llvm::enumerate(callOp.getInputs())) {
+      if (!isa<RankedTensorType>(arg.getType())) {
+        orig->emitError()
+            << "unsupported forward rule for triton kernel call with "
+               "non-array argument at argument #"
+            << i << " of type " << arg.getType() << ".";
+        return failure();
+      }
+    }
+
+    for (auto [i, res] : llvm::enumerate(callOp->getResults())) {
+      if (!isa<RankedTensorType>(res.getType())) {
+        orig->emitError()
+            << "unsupported forward rule for triton kernel call with "
+               "non-array return at return #"
+            << i << " of type " << res.getType() << ".";
+        return failure();
+      }
+    }
+
+    auto output_operand_aliases = callOp.getOutputOperandAliases();
+    auto operandLayouts = dyn_cast_or_null<ArrayAttr>(
+        callOp.getOperandLayouts().value_or(nullptr));
+    auto resultLayouts = dyn_cast_or_null<ArrayAttr>(
+        callOp.getResultLayouts().value_or(nullptr));
+
+    Operation *callee =
+        SymbolTable::lookupNearestSymbolFrom(callOp, callOp.getFn());
+    auto fn = cast<FunctionOpInterface>(callee);
+
+    size_t width = gutils->width;
+
+    int numInputs = callOp.getInputs().size();
+    int narg = numInputs + orig->getNumResults();
+
+    std::vector<DIFFE_TYPE> RetActivity;
+    std::vector<bool> returnPrimal;
+    std::vector<bool> returnShadow;
+
+    // Unless there is aliasing, result values are assumed to be appended
+    // to the argument list of the triton kernel.
+    SmallVector<int64_t> operandIndexMap;
+
+    unsigned argCnt = 0;
+
+    std::vector<DIFFE_TYPE> ArgActivity;
+    for (auto arg : callOp.getInputs()) {
+      auto act = gutils->isConstantValue(arg) ? DIFFE_TYPE::CONSTANT
+                                              : DIFFE_TYPE::DUP_ARG;
+      operandIndexMap.push_back(argCnt);
+      ArgActivity.push_back(act);
+      argCnt++;
+      if (act == DIFFE_TYPE::DUP_ARG)
+        argCnt++;
+    }
+
+    for (auto [i, res] : llvm::enumerate(callOp.getResults())) {
+      auto aliasedOperandIndex = findAliasedOperand(output_operand_aliases, i);
+      if (!aliasedOperandIndex.has_value()) {
+        auto act = gutils->isConstantValue(res) ? DIFFE_TYPE::CONSTANT
+                                                : DIFFE_TYPE::DUP_ARG;
+        ArgActivity.push_back(act);
+      } else {
+        narg--;
+      }
+    }
+
+    auto type_args = gutils->TA.getAnalyzedTypeInfo(fn);
+
+    bool freeMemory = true;
+
+    std::vector<bool> volatile_args(narg, false);
+
+    auto forwardFn = gutils->Logic.CreateForwardDiff(
+        fn, RetActivity, ArgActivity, gutils->TA, returnPrimal, mode,
+        freeMemory, width,
+        /* addedType */ nullptr, type_args, volatile_args,
+        /* augmented */ nullptr, gutils->omp, gutils->postpasses,
+        gutils->verifyPostPasses, gutils->strongZero);
+
+    SmallVector<Value> fwdArguments;
+    SmallVector<Type> returnTypes;
+
+    // Let's assume the same layout for a value and its shadow.
+    SmallVector<Attribute> newOperandLayouts;
+    SmallVector<Attribute> newResultLayouts;
+
+    unsigned argIdx = 0;
+    for (auto &&[arg, act] : llvm::zip(callOp.getInputs(), ArgActivity)) {
+      fwdArguments.push_back(gutils->getNewFromOriginal(arg));
+
+      if (operandLayouts) {
+        newOperandLayouts.push_back(operandLayouts[argIdx]);
+        if (act == DIFFE_TYPE::DUP_ARG)
+          newOperandLayouts.push_back(operandLayouts[argIdx]);
+      }
+      argIdx++;
+
+      if (act == DIFFE_TYPE::DUP_ARG)
+        fwdArguments.push_back(gutils->invertPointerM(arg, builder));
+    }
+
+    SmallVector<Attribute> newOutputOperandAliases;
+
+    unsigned naliased = 0;
+    for (auto &&[i, res] : llvm::enumerate(callOp->getResults())) {
+      auto aliasedOperandIndex = findAliasedOperand(output_operand_aliases, i);
+
+      DIFFE_TYPE act;
+      if (aliasedOperandIndex.has_value()) {
+        naliased++;
+
+        act = ArgActivity[*aliasedOperandIndex];
+
+        auto newOperandIndex = operandIndexMap[*aliasedOperandIndex];
+        int64_t newResultIndex = returnTypes.size();
+        newOutputOperandAliases.push_back(
+            stablehlo::OutputOperandAliasAttr::get(
+                callOp.getContext(), ArrayRef<int64_t>{newResultIndex},
+                newOperandIndex, ArrayRef<int64_t>{}));
+
+        if (act == DIFFE_TYPE::DUP_ARG) {
+          newOutputOperandAliases.push_back(
+              stablehlo::OutputOperandAliasAttr::get(
+                  callOp.getContext(), ArrayRef<int64_t>{newResultIndex + 1},
+                  newOperandIndex + 1, ArrayRef<int64_t>{}));
+        }
+      } else {
+        act = ArgActivity[i - naliased + numInputs];
+      }
+
+      if (resultLayouts) {
+        newResultLayouts.push_back(resultLayouts[i]);
+        if (act == DIFFE_TYPE::DUP_ARG)
+          newResultLayouts.push_back(resultLayouts[i]);
+      }
+
+      returnTypes.push_back(res.getType());
+      if (act == DIFFE_TYPE::DUP_ARG)
+        returnTypes.push_back(
+            cast<AutoDiffTypeInterface>(res.getType()).getShadowType(width));
+    }
+
+    SmallVector<FlatSymbolRefAttr> nestedRefs = {
+        FlatSymbolRefAttr::get(
+            forwardFn->getParentOfType<ModuleOp>().getSymNameAttr()),
+        FlatSymbolRefAttr::get(
+            StringAttr::get(callOp.getContext(), forwardFn.getName()))};
+    auto fnRef = SymbolRefAttr::get(
+        callOp.getContext(),
+        forwardFn->getParentOfType<triton_ext::ModuleOp>().getSymName(),
+        nestedRefs);
+
+    Value gridx = gutils->getNewFromOriginal(callOp.getGridx()),
+          gridy = gutils->getNewFromOriginal(callOp.getGridy()),
+          gridz = gutils->getNewFromOriginal(callOp.getGridz());
+
+    Value clusterx = gutils->getNewFromOriginal(callOp.getClusterx()),
+          clustery = gutils->getNewFromOriginal(callOp.getClustery()),
+          clusterz = gutils->getNewFromOriginal(callOp.getClusterz());
+
+    Attribute newOperandLayoutsAttr =
+        operandLayouts ? ArrayAttr::get(callOp.getContext(), newOperandLayouts)
+                       : nullptr;
+    Attribute newResultLayoutsAttr =
+        resultLayouts ? ArrayAttr::get(callOp.getContext(), newResultLayouts)
+                      : nullptr;
+
+    auto fwdCallOp = triton_ext::TritonCallOp::create(
+        builder, callOp.getLoc(), TypeRange(returnTypes),
+        /*fn*/ fnRef,
+        gridx, gridy, gridz,
+        clusterx, clustery, clusterz,
+        ValueRange(fwdArguments),
+        /* backendConfig */ StringAttr::get(callOp.getContext(), ""),
+        newOperandLayoutsAttr, newResultLayoutsAttr,
+        /* argAttrs */ mlir::ArrayAttr::get(callOp.getContext(), {}),
+        /* resAttrs */ mlir::ArrayAttr::get(callOp.getContext(), {}),
+        ArrayAttr::get(callOp.getContext(), newOutputOperandAliases),
+        /* xla_side_effect_free */ nullptr);
+
+    SmallVector<Value> primals;
+    primals.reserve(callOp->getNumResults());
+
+    naliased = 0;
+    int fwdIndex = 0;
+    for (auto &&[i, ret] : llvm::enumerate(callOp.getResults())) {
+      auto fwdRet = fwdCallOp.getResult(fwdIndex);
+      primals.push_back(fwdRet);
+
+      fwdIndex++;
+
+      auto aliasedOperandIndex = findAliasedOperand(output_operand_aliases, i);
+
+      DIFFE_TYPE act;
+      if (aliasedOperandIndex.has_value()) {
+        act = ArgActivity[*aliasedOperandIndex];
+        naliased++;
+      } else {
+        act = ArgActivity[i - naliased + numInputs];
+      }
+
+      if (act == DIFFE_TYPE::DUP_ARG) {
+        gutils->setDiffe(ret, fwdCallOp.getResult(fwdIndex), builder);
+        fwdIndex++;
+      }
+    }
+
+    auto newOp = gutils->getNewFromOriginal(orig);
+    gutils->replaceOrigOpWith(orig, primals);
+    gutils->erase(newOp);
+
+    return success();
+  }
+};
+
+} // end anonymous namespace
+
+void mlir::enzyme::registerTritonExtDialectAutoDiffInterface(
+    mlir::DialectRegistry &registry) {
+  registry.addExtension(+[](MLIRContext *context,
+                            triton_ext::TritonExtDialect *) {
+    triton_ext::TritonCallOp::attachInterface<AutoDiffTritonCallFwd>(*context);
+  });
+}
diff --git a/src/enzyme_ad/jax/Implementations/XLADerivatives.h b/src/enzyme_ad/jax/Implementations/XLADerivatives.h
index abb7b9546..4b09b2d69 100644
--- a/src/enzyme_ad/jax/Implementations/XLADerivatives.h
+++ b/src/enzyme_ad/jax/Implementations/XLADerivatives.h
@@ -15,6 +15,7 @@ void registerStableHLODialectAutoDiffInterface(mlir::DialectRegistry &registry);
 void registerCHLODialectAutoDiffInterface(mlir::DialectRegistry &registry);
 void registerEnzymeXLADialectAutoDiffInterface(mlir::DialectRegistry &registry);
 void registerTritonDialectAutoDiffInterface(mlir::DialectRegistry &registry);
+void registerTritonExtDialectAutoDiffInterface(mlir::DialectRegistry &registry);
 
 static inline void
 registerXLAAutoDiffInterfaces(mlir::DialectRegistry &registry) {
@@ -23,6 +24,7 @@ registerXLAAutoDiffInterfaces(mlir::DialectRegistry &registry) {
   registerCHLODialectAutoDiffInterface(registry);
   registerEnzymeXLADialectAutoDiffInterface(registry);
   registerTritonDialectAutoDiffInterface(registry);
+  registerTritonExtDialectAutoDiffInterface(registry);
 }
 } // namespace enzyme
 } // namespace mlir
diff --git a/test/lit_tests/diffrules/triton/add.mlir b/test/lit_tests/diffrules/triton/add.mlir
index 22621d68b..c7e0445c6 100644
--- a/test/lit_tests/diffrules/triton/add.mlir
+++ b/test/lit_tests/diffrules/triton/add.mlir
@@ -1,53 +1,76 @@
-// RUN: enzymexlamlir-opt %s --enzyme-wrap="infn=add_kernel outfn= argTys=enzyme_dup,enzyme_const,enzyme_dup,enzyme_const retTys= mode=ForwardMode" | FileCheck %s
+// RUN: enzymexlamlir-opt %s --enzyme-wrap="infn=main outfn= argTys=enzyme_dup,enzyme_dup,enzyme_dup,enzyme_const retTys=enzyme_dup,enzyme_dup,enzyme_dup mode=ForwardMode" --canonicalize | FileCheck %s
 module {
-  tt.func public @add_kernel(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} {
-    %c1024_i32 = arith.constant 1024 : i32
-    %0 = tt.get_program_id x : i32
-    %1 = arith.muli %0, %c1024_i32 : i32
-    %2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32>
-    %3 = tt.splat %1 : i32 -> tensor<1024xi32>
-    %4 = arith.addi %3, %2 : tensor<1024xi32>
-    %5 = tt.splat %arg3 : i32 -> tensor<1024xi32>
-    %6 = arith.cmpi slt, %4, %5 : tensor<1024xi32>
-    %7 = tt.splat %arg0 : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>>
-    %8 = tt.addptr %7, %4 : tensor<1024x!tt.ptr<f32>>, tensor<1024xi32>
-    %9 = tt.load %8, %6 : tensor<1024x!tt.ptr<f32>>
-    %10 = tt.splat %arg1 : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>>
-    %11 = tt.addptr %10, %4 : tensor<1024x!tt.ptr<f32>>, tensor<1024xi32>
-    %12 = tt.load %11, %6 : tensor<1024x!tt.ptr<f32>>
-    %13 = arith.addf %9, %12 : tensor<1024xf32>
-    %14 = tt.splat %arg2 : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>>
-    %15 = tt.addptr %14, %4 : tensor<1024x!tt.ptr<f32>>, tensor<1024xi32>
-    tt.store %15, %13, %6 : tensor<1024x!tt.ptr<f32>>
-    tt.return
+  enzymexla_tt_ext.module @add_kernel_tt {
+    builtin.module @add_kernel_inner {
+      tt.func public @add_kernel(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32},
+                                 %arg3: i32 {tt.divisibility = 16 : i32},
+                                 %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}) attributes {noinline = false} {
+        %c1024_i32 = arith.constant 1024 : i32
+        %0 = tt.get_program_id x : i32
+        %1 = arith.muli %0, %c1024_i32 : i32
+        %2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32>
+        %3 = tt.splat %1 : i32 -> tensor<1024xi32>
+        %4 = arith.addi %3, %2 : tensor<1024xi32>
+        %5 = tt.splat %arg3 : i32 -> tensor<1024xi32>
+        %6 = arith.cmpi slt, %4, %5 : tensor<1024xi32>
+        %7 = tt.splat %arg0 : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>>
+        %8 = tt.addptr %7, %4 : tensor<1024x!tt.ptr<f32>>, tensor<1024xi32>
+        %9 = tt.load %8, %6 : tensor<1024x!tt.ptr<f32>>
+        %10 = tt.splat %arg1 : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>>
+        %11 = tt.addptr %10, %4 : tensor<1024x!tt.ptr<f32>>, tensor<1024xi32>
+        %12 = tt.load %11, %6 : tensor<1024x!tt.ptr<f32>>
+        %13 = arith.addf %9, %12 : tensor<1024xf32>
+        %14 = tt.splat %arg2 : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>>
+        %15 = tt.addptr %14, %4 : tensor<1024x!tt.ptr<f32>>, tensor<1024xi32>
+        tt.store %15, %13, %6 : tensor<1024x!tt.ptr<f32>>
+        tt.return
+      }
+    }
+  }
+  func.func @main(%arg0: tensor<1024xf32>, %arg1: tensor<1024xf32>, %arg2: tensor<1024xf32>, %arg3: tensor<i32>) -> (tensor<1024xf32>, tensor<1024xf32>, tensor<1024xf32>) {
+    %c_0 = stablehlo.constant dense<1> : tensor<i64>
+    %c_1 = stablehlo.constant dense<16> : tensor<i64>
+    %0:3 = enzymexla_tt_ext.call @add_kernel_tt::@add_kernel_inner::@add_kernel clusters in(%c_0, %c_0, %c_0) blocks in(%c_1, %c_0, %c_0) (%arg0, %arg1, %arg3) {output_operand_aliases = [#stablehlo.output_operand_alias<output_tuple_indices = [0], operand_index = 0, operand_tuple_indices = []>, #stablehlo.output_operand_alias<output_tuple_indices = [1], operand_index = 1, operand_tuple_indices = []>]} : (tensor<1024xf32>, tensor<1024xf32>, tensor<i32>) -> (tensor<1024xf32>, tensor<1024xf32>, tensor<1024xf32>)
+    return %0#0, %0#1, %0#2 : tensor<1024xf32>, tensor<1024xf32>, tensor<1024xf32>
   }
 }
 
-// CHECK: tt.func @add_kernel(%[[arg0:.+]]: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %[[arg1:.+]]: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %[[arg2:.+]]: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %[[arg3:.+]]: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %[[arg4:.+]]: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %[[arg5:.+]]: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} {
-// CHECK-NEXT: %[[c1024_i32:.+]] = arith.constant 1024 : i32
-// CHECK-NEXT: %[[v0:.+]] = tt.get_program_id x : i32
-// CHECK-NEXT: %[[v1:.+]] = arith.muli %[[v0]], %[[c1024_i32]] : i32
-// CHECK-NEXT: %[[v2:.+]] = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32>
-// CHECK-NEXT: %[[v3:.+]] = tt.splat %[[v1]] : i32 -> tensor<1024xi32>
-// CHECK-NEXT: %[[v4:.+]] = arith.addi %[[v3]], %[[v2]] : tensor<1024xi32>
-// CHECK-NEXT: %[[v5:.+]] = tt.splat %[[arg5]] : i32 -> tensor<1024xi32>
-// CHECK-NEXT: %[[v6:.+]] = arith.cmpi slt, %[[v4]], %[[v5]] : tensor<1024xi32>
-// CHECK-NEXT: %[[v7:.+]] = tt.splat %[[arg1]] : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>>
-// CHECK-NEXT: %[[v8:.+]] = tt.splat %[[arg0]] : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>>
-// CHECK-NEXT: %[[v9:.+]] = tt.addptr %[[v7]], %[[v4]] : tensor<1024x!tt.ptr<f32>>, tensor<1024xi32>
-// CHECK-NEXT: %[[v10:.+]] = tt.addptr %[[v8]], %[[v4]] : tensor<1024x!tt.ptr<f32>>, tensor<1024xi32>
-// CHECK-NEXT: %[[v11:.+]] = tt.load %[[v9]], %[[v6]] : tensor<1024x!tt.ptr<f32>>
-// CHECK-NEXT: %[[v12:.+]] = tt.load %[[v10]], %[[v6]] : tensor<1024x!tt.ptr<f32>>
-// CHECK-NEXT: %[[v13:.+]] = tt.splat %[[arg2]] : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>>
-// CHECK-NEXT: %[[v14:.+]] = tt.addptr %[[v13]], %[[v4]] : tensor<1024x!tt.ptr<f32>>, tensor<1024xi32>
-// CHECK-NEXT: %[[v15:.+]] = tt.load %[[v14]], %[[v6]] : tensor<1024x!tt.ptr<f32>>
-// CHECK-NEXT: %[[v16:.+]] = arith.addf %[[v12]], %[[v15]] : tensor<1024xf32>
-// CHECK-NEXT: %[[v17:.+]] = tt.splat %[[arg4]] : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>>
-// CHECK-NEXT: %[[v18:.+]] = tt.splat %[[arg3]] : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>>
-// CHECK-NEXT: %[[v19:.+]] = tt.addptr %[[v17]], %[[v4]] : tensor<1024x!tt.ptr<f32>>, tensor<1024xi32>
-// CHECK-NEXT: %[[v20:.+]] = tt.addptr %[[v18]], %[[v4]] : tensor<1024x!tt.ptr<f32>>, tensor<1024xi32>
-// CHECK-NEXT: tt.store %[[v19]], %[[v11]], %[[v6]] : tensor<1024x!tt.ptr<f32>>
-// CHECK-NEXT: tt.store %[[v20]], %[[v16]], %[[v6]] : tensor<1024x!tt.ptr<f32>>
-// CHECK-NEXT: tt.return
+// CHECK: tt.func private @fwddiffeadd_kernel(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg3: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg6: !tt.ptr<f32> {tt.divisibility = 16 : i32}) attributes {noinline = false} {
+// CHECK-NEXT: %c1024_i32 = arith.constant 1024 : i32
+// CHECK-NEXT: %[[v0:.+]] = tt.get_program_id x : i32
+// CHECK-NEXT: %[[v1:.+]] = arith.muli %[[v0]], %c1024_i32 : i32
+// CHECK-NEXT: %[[v2:.+]] = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32>
+// CHECK-NEXT: %[[v3:.+]] = tt.splat %[[v1]] : i32 -> tensor<1024xi32>
+// CHECK-NEXT: %[[v4:.+]] = arith.addi %[[v3]], %[[v2]] : tensor<1024xi32>
+// CHECK-NEXT: %[[v5:.+]] = tt.splat %arg4 : i32 -> tensor<1024xi32>
+// CHECK-NEXT: %[[v6:.+]] = arith.cmpi slt, %[[v4]], %[[v5]] : tensor<1024xi32>
+// CHECK-NEXT: %[[v7:.+]] = tt.splat %arg1 : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>>
+// CHECK-NEXT: %[[v8:.+]] = tt.splat %arg0 : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>>
+// CHECK-NEXT: %[[v9:.+]] = tt.addptr %[[v7]], %[[v4]] : tensor<1024x!tt.ptr<f32>>, tensor<1024xi32>
+// CHECK-NEXT: %[[v10:.+]] = tt.addptr %[[v8]], %[[v4]] : tensor<1024x!tt.ptr<f32>>, tensor<1024xi32>
+// CHECK-NEXT: %[[v11:.+]] = tt.load %[[v9]], %[[v6]] : tensor<1024x!tt.ptr<f32>>
+// CHECK-NEXT: %[[v12:.+]] = tt.load %[[v10]], %[[v6]] : tensor<1024x!tt.ptr<f32>>
+// CHECK-NEXT: %[[v13:.+]] = tt.splat %arg3 : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>>
+// CHECK-NEXT: %[[v14:.+]] = tt.splat %arg2 : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>>
+// CHECK-NEXT: %[[v15:.+]] = tt.addptr %[[v13]], %[[v4]] : tensor<1024x!tt.ptr<f32>>, tensor<1024xi32>
+// CHECK-NEXT: %[[v16:.+]] = tt.addptr %[[v14]], %[[v4]] : tensor<1024x!tt.ptr<f32>>, tensor<1024xi32>
+// CHECK-NEXT: %[[v17:.+]] = tt.load %[[v15]], %[[v6]] : tensor<1024x!tt.ptr<f32>>
+// CHECK-NEXT: %[[v18:.+]] = tt.load %[[v16]], %[[v6]] : tensor<1024x!tt.ptr<f32>>
+// CHECK-NEXT: %[[v19:.+]] = arith.addf %[[v11]], %[[v17]] : tensor<1024xf32>
+// CHECK-NEXT: %[[v20:.+]] = arith.addf %[[v12]], %[[v18]] : tensor<1024xf32>
+// CHECK-NEXT: %[[v21:.+]] = tt.splat %arg6 : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>>
+// CHECK-NEXT: %[[v22:.+]] = tt.splat %arg5 : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>>
+// CHECK-NEXT: %[[v23:.+]] = tt.addptr %[[v21]], %[[v4]] : tensor<1024x!tt.ptr<f32>>, tensor<1024xi32>
+// CHECK-NEXT: %[[v24:.+]] = tt.addptr %[[v22]], %[[v4]] : tensor<1024x!tt.ptr<f32>>, tensor<1024xi32>
+// CHECK-NEXT: tt.store %[[v23]], %[[v19]], %[[v6]] : tensor<1024x!tt.ptr<f32>>
+// CHECK-NEXT: tt.store %[[v24]], %[[v20]], %[[v6]] : tensor<1024x!tt.ptr<f32>>
+// CHECK-NEXT: tt.return
+// CHECK-NEXT: }
+
+// CHECK: func.func @main(%arg0: tensor<1024xf32>, %arg1: tensor<1024xf32>, %arg2: tensor<1024xf32>, %arg3: tensor<1024xf32>, %arg4: tensor<1024xf32>, %arg5: tensor<1024xf32>, %arg6: tensor<i32>) -> (tensor<1024xf32>, tensor<1024xf32>, tensor<1024xf32>, tensor<1024xf32>, tensor<1024xf32>, tensor<1024xf32>) {
+// CHECK-NEXT: %c = stablehlo.constant dense<1> : tensor<i64>
+// CHECK-NEXT: %c_0 = stablehlo.constant dense<16> : tensor<i64>
+// CHECK-NEXT: %0:6 = enzymexla_tt_ext.call @add_kernel_tt::@add_kernel_inner::@fwddiffeadd_kernel clusters in(%c, %c, %c) blocks in(%c_0, %c, %c) (%arg0, %arg1, %arg2, %arg3, %arg6) {arg_attrs = [], output_operand_aliases = [#stablehlo.output_operand_alias<output_tuple_indices = [0], operand_index = 0, operand_tuple_indices = []>, #stablehlo.output_operand_alias<output_tuple_indices = [1], operand_index = 1, operand_tuple_indices = []>, #stablehlo.output_operand_alias<output_tuple_indices = [2], operand_index = 2, operand_tuple_indices = []>, #stablehlo.output_operand_alias<output_tuple_indices = [3], operand_index = 3, operand_tuple_indices = []>], res_attrs = []} : (tensor<1024xf32>, tensor<1024xf32>, tensor<1024xf32>, tensor<1024xf32>, tensor<i32>) -> (tensor<1024xf32>, tensor<1024xf32>, tensor<1024xf32>, tensor<1024xf32>, tensor<1024xf32>, tensor<1024xf32>)
+// CHECK-NEXT: return %0#0, %0#1, %0#2, %0#3, %0#4, %0#5 : tensor<1024xf32>, tensor<1024xf32>, tensor<1024xf32>, tensor<1024xf32>, tensor<1024xf32>, tensor<1024xf32>
 // CHECK-NEXT: }
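
// Note on the aliasing convention exercised above: the forward rule places each
// active (enzyme_dup) operand's shadow immediately after that operand and each
// shadow result immediately after its primal, so every aliased (result, operand)
// pair gains a shadow pair at the next indices. A minimal sketch of the index
// bookkeeping, assuming one aliased result and one active operand (hypothetical
// values, not part of the test):
//
//   Before AD:  call (%x, %n) -> (%out)
//     aliases:  output [0] <-> operand 0
//   After forward mode with %x active, %dx becomes operand #1 and the shadow
//   result %dout becomes result #1:
//               call (%x, %dx, %n) -> (%out, %dout)
//     aliases:  output [0] <-> operand 0, output [1] <-> operand 1
//
// This mirrors the newResultIndex + 1 / newOperandIndex + 1 pairing built in
// TritonExtAutoDiffOpInterfaceImpl.cpp.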