diff --git a/src/enzyme_ad/jax/Implementations/HLODerivatives.td b/src/enzyme_ad/jax/Implementations/HLODerivatives.td index ba3917ab3..5531012a2 100644 --- a/src/enzyme_ad/jax/Implementations/HLODerivatives.td +++ b/src/enzyme_ad/jax/Implementations/HLODerivatives.td @@ -977,8 +977,6 @@ def : HLODerivative<"Atan2Op", (Op $x, $y), (CheckedDiv (Sub (Mul $x, (Shadow $y)), (Mul $y, (Shadow $x))), (Add (Pow $x, (HLOConstantFP<"2">)), (Pow $y, (HLOConstantFP<"2"> $y)))) >; -def : HLOReadOnlyIdentityOp<"BroadcastInDimOp">; - // Gather Adjoint def ResultDimensionNumbers : GlobalExpr; + +def getBroadcastDimensionsWithBatch : GlobalExpr newBcastDims; + for (auto dim : bcastDims) { + newBcastDims.push_back(to_i64(dim)); + } + if (gutils->width > 1) { + newBcastDims.insert(newBcastDims.begin(), gutils->width); + } + getI64Attr(builder, newBcastDims); +}]>; + +def BroadcastDimsToReductionDims : GlobalExpr reduceDims; + auto outTy = cast(op.getType()); + auto bcastDims = op.getBroadcastDimensions(); + auto inTy = cast(op.getOperand().getType()); + + for (auto en : llvm::enumerate(outTy.getShape())) { + ssize_t bcastIdx = -1; + for (auto en2 : llvm::enumerate(bcastDims)) { + if (en2.value() == en.index()) { + bcastIdx = en2.index(); + break; + } + } + if (bcastIdx != -1) { + if (en.value() != inTy.getShape()[bcastIdx]) { + reduceDims.push_back(en.index()); + assert(inTy.getShape()[bcastIdx] == 1); + } + continue; + } + reduceDims.push_back(en.index()); + } + + if (gutils->width > 1) { + for (int i = 0; i < reduceDims.size(); i++) { + reduceDims[i]++; + } + } + getI64Attr(builder, reduceDims); +}]>; + +def ReduceAddNullValue : GlobalExprgetOperand(0); + auto operandElemType = cast(operand.getType()).getElementType(); + auto bodyTy = RankedTensorType::get({}, operandElemType); + cast( + gutils->getShadowType(bodyTy)).createNullValue(builder, op.getLoc()); +}]>; + +def BroadcastDimensionsToInversePermutation : GlobalExpr(op.getType()); + SmallVector perm(outTy.getRank()); + SmallVector mapping(outTy.getRank(), -1); + for (auto [i, dim] : llvm::enumerate(op.getBroadcastDimensions())) { + mapping[to_i64(dim)] = i; + } + + int next = op.getBroadcastDimensions().size(); + for (int i = 0; i < outTy.getRank(); i++) { + if (mapping[i] == -1) { + mapping[i] = next++; + } + } + + for (int i = 0; i < outTy.getRank(); i++) { + perm[mapping[i]] = i; + } + getI64Attr(builder, perm); +}]>; + +def InsertDeletedReduceDimsType : GlobalExpr reduceDims; + auto outTy = cast(op.getType()); + auto bcastDims = op.getBroadcastDimensions(); + auto inTy = cast(op.getOperand().getType()); + auto outShape = outTy.getShape(); + + for (auto en : llvm::enumerate(outTy.getShape())) { + ssize_t bcastIdx = -1; + for (auto en2 : llvm::enumerate(bcastDims)) { + if (en2.value() == en.index()) { + bcastIdx = en2.index(); + break; + } + } + if (bcastIdx != -1) { + if (en.value() != inTy.getShape()[bcastIdx]) { + reduceDims.push_back(en.index()); + assert(inTy.getShape()[bcastIdx] == 1); + } + continue; + } + reduceDims.push_back(en.index()); + } + + SmallVector reshapeShape(outTy.getRank(), -1); + for (auto [i, sz] : llvm::enumerate(outShape)) { + if (llvm::is_contained(reduceDims, i)) { + reshapeShape[i] = 1; + } else { + reshapeShape[i] = sz; + } + } + + if (gutils->width > 1) { + reshapeShape.insert(reshapeShape.begin(), gutils->width); + } + RankedTensorType::get(reshapeShape, outTy.getElementType()); +}]>; + +def ReduceAdd : HLOInst<"ReduceOp", ")->getResult(0)", "createAddRegion(">; + +def ResultTypeWithBatch : GlobalExpr(op.getType()); + auto outShape = outTy.getShape(); + SmallVector newShape(outShape.begin(), outShape.end()); + if (gutils->width > 1) { + newShape.insert(newShape.begin(), gutils->width); + } + RankedTensorType::get(newShape, outTy.getElementType()); +}]>; + +def InputTypeWithBatch : GlobalExpr(op.getOperand().getType()); + auto inShape = inTy.getShape(); + SmallVector newShape(inShape.begin(), inShape.end()); + if (gutils->width > 1) { + newShape.insert(newShape.begin(), gutils->width); + } + RankedTensorType::get(newShape, inTy.getElementType()); +}]>; + +def : HLODerivative<"BroadcastInDimOp", (Op $x), + [ + ( + Reshape + (InputTypeWithBatch), + ( + Transpose + ( + Reshape + (InsertDeletedReduceDimsType), + (ReduceAdd (DiffeRet), (ReduceAddNullValue), (BroadcastDimsToReductionDims)) + ), + (BroadcastDimensionsToInversePermutation) + ) + ) + ], + ( + BroadcastInDim (ResultTypeWithBatch), (Shadow $x), (getBroadcastDimensionsWithBatch) + )>; diff --git a/src/enzyme_ad/jax/Implementations/StableHLOAutoDiffOpInterfaceImpl.cpp b/src/enzyme_ad/jax/Implementations/StableHLOAutoDiffOpInterfaceImpl.cpp index 0f9449604..e0dc4e62e 100644 --- a/src/enzyme_ad/jax/Implementations/StableHLOAutoDiffOpInterfaceImpl.cpp +++ b/src/enzyme_ad/jax/Implementations/StableHLOAutoDiffOpInterfaceImpl.cpp @@ -1133,116 +1133,6 @@ class AutoDiffWhileRev } }; -class AutoDiffBroadcastInDimRev - : public ReverseAutoDiffOpInterface::ExternalModel< - AutoDiffBroadcastInDimRev, BroadcastInDimOp> { -public: - LogicalResult createReverseModeAdjoint(Operation *orig, OpBuilder &builder, - MGradientUtilsReverse *gutils, - SmallVector caches) const { - auto op = cast(orig); - auto inTy = op.getOperand().getType(); - auto outTy = op.getType(); - auto inDiffe = gutils->diffe(op, builder); - gutils->zeroDiffe(op, builder); - - SmallVector bcastDims(op.getBroadcastDimensions().begin(), - op.getBroadcastDimensions().end()); - - SmallVector reducedDims; - SmallVector iterShape; - for (auto en : llvm::enumerate(outTy.getShape())) { - ssize_t bcastIdx = -1; - for (auto en2 : llvm::enumerate(bcastDims)) { - if (en2.value() == en.index()) { - bcastIdx = en2.index(); - break; - } - } - if (bcastIdx != -1) { - if (en.value() != inTy.getShape()[bcastIdx]) { - reducedDims.push_back(en.index()); - assert(inTy.getShape()[bcastIdx] == 1); - } else { - iterShape.push_back(inTy.getShape()[bcastIdx]); - } - continue; - } - reducedDims.push_back(en.index()); - } - - SmallVector reshapedShape(outTy.getRank(), -1); - for (auto [i, sz] : llvm::enumerate(outTy.getShape())) { - if (llvm::is_contained(reducedDims, i)) { - reshapedShape[i] = 1; - } else { - reshapedShape[i] = sz; - } - } - - SmallVector perm(outTy.getRank(), -1); - SmallVector mapping(outTy.getRank(), -1); - for (auto [i, dim] : llvm::enumerate(bcastDims)) { - mapping[dim] = i; - } - - int next = bcastDims.size(); - for (int i = 0; i < outTy.getRank(); i++) { - if (mapping[i] == -1) { - mapping[i] = next++; - } - } - - for (int i = 0; i < outTy.getRank(); i++) { - perm[mapping[i]] = i; - } - - auto reduceTy = RankedTensorType::get(iterShape, inTy.getElementType()); - auto bodyTy = RankedTensorType::get({}, inTy.getElementType()); - - Value zero = cast(gutils->getShadowType(bodyTy)) - .createNullValue(builder, op.getLoc()); - - auto red = builder.create( - op.getLoc(), TypeRange(gutils->getShadowType(reduceTy)), inDiffe, zero, - reducedDims); - red.getBody().push_back(new Block()); - Block &body = red.getBody().front(); - OpBuilder bodyBuilder(orig->getContext()); - bodyBuilder.setInsertionPointToEnd(&body); - - body.addArgument(bodyTy, op.getLoc()); - body.addArgument(bodyTy, op.getLoc()); - auto add = bodyBuilder.create(op.getLoc(), body.getArgument(0), - body.getArgument(1)); - bodyBuilder.create(op.getLoc(), ValueRange(add)); - - // for simplicity we do grad -> reduce -> reshape (restore 1 dims) -> - // transpose -> reshape - // The repeated reshapes are then eliminated via `enzyme-hlo-opt`. - auto reshapedRed = builder.create( - op.getLoc(), - RankedTensorType::get(reshapedShape, inTy.getElementType()), - red->getResult(0)); - auto transposedVal = - builder.create(op.getLoc(), reshapedRed, perm); - auto res = builder.create( - op.getLoc(), gutils->getShadowType(op.getOperand().getType()), - transposedVal); - - gutils->addToDiffe(op.getOperand(), res, builder); - return success(); - } - - SmallVector cacheValues(Operation *orig, - MGradientUtilsReverse *gutils) const { - return {}; - } - - void createShadowValues(Operation *op, OpBuilder &builder, - MGradientUtilsReverse *gutils) const {} -}; - class AutoDiffSliceRev : public ReverseAutoDiffOpInterface::ExternalModel { @@ -3635,7 +3525,6 @@ void mlir::enzyme::registerStableHLODialectAutoDiffInterface( WhileOp::attachInterface(*context); ReduceOp::attachInterface>(*context); WhileOp::attachInterface>(*context); - BroadcastInDimOp::attachInterface(*context); SliceOp::attachInterface(*context); ReduceOp::attachInterface(*context); ReduceWindowOp::attachInterface(*context); diff --git a/test/lit_tests/diffrules/stablehlo/broadcastindim3.mlir b/test/lit_tests/diffrules/stablehlo/broadcastindim3.mlir new file mode 100644 index 000000000..dcb876931 --- /dev/null +++ b/test/lit_tests/diffrules/stablehlo/broadcastindim3.mlir @@ -0,0 +1,32 @@ +// RUN: enzymexlamlir-opt --enzyme --canonicalize --remove-unnecessary-enzyme-ops --arith-raise --inline --enzyme-hlo-opt %s | FileCheck %s + +module { + func.func private @"Const{typeof(slicing)}(Main.slicing)_autodiff"(%arg0: tensor<1x4x1xf32>) -> (tensor, tensor<1x4x1xf32>) { + %cst = stablehlo.constant dense<0.000000e+00> : tensor + %cst_0 = stablehlo.constant dense<1.000000e+00> : tensor<3xf32> + %0 = stablehlo.slice %arg0 [0:1, 0:1, 0:1] : (tensor<1x4x1xf32>) -> tensor<1x1x1xf32> + %1 = stablehlo.reshape %0 : (tensor<1x1x1xf32>) -> tensor<1xf32> + %2 = stablehlo.broadcast_in_dim %1, dims = [0] : (tensor<1xf32>) -> tensor<3xf32> + %3 = stablehlo.multiply %2, %cst_0 : tensor<3xf32> + %4 = stablehlo.multiply %3, %3 : tensor<3xf32> + %5 = stablehlo.reduce(%4 init: %cst) applies stablehlo.add across dimensions = [0] : (tensor<3xf32>, tensor) -> tensor + return %5, %arg0 : tensor, tensor<1x4x1xf32> + } + func.func @main(%arg0: tensor<1x4x1xf32>) -> (tensor<1x4x1xf32>, tensor<1x4x1xf32>) { + %cst = stablehlo.constant dense<1.000000e+00> : tensor + %0:2 = enzyme.autodiff @"Const{typeof(slicing)}(Main.slicing)_autodiff"(%arg0, %cst) {activity = [#enzyme], ret_activity = [#enzyme, #enzyme]} : (tensor<1x4x1xf32>, tensor) -> (tensor<1x4x1xf32>, tensor<1x4x1xf32>) + return %0#1, %0#0 : tensor<1x4x1xf32>, tensor<1x4x1xf32> + } +} + +// CHECK: func.func @main(%arg0: tensor<1x4x1xf32>) -> (tensor<1x4x1xf32>, tensor<1x4x1xf32>) { +// CHECK-NEXT: %cst = stablehlo.constant dense<0.000000e+00> : tensor +// CHECK-NEXT: %0 = stablehlo.slice %arg0 [0:1, 0:1, 0:1] : (tensor<1x4x1xf32>) -> tensor<1x1x1xf32> +// CHECK-NEXT: %1 = stablehlo.reshape %0 : (tensor<1x1x1xf32>) -> tensor<1xf32> +// CHECK-NEXT: %2 = stablehlo.broadcast_in_dim %1, dims = [0] : (tensor<1xf32>) -> tensor<3xf32> +// CHECK-NEXT: %3 = stablehlo.add %2, %2 : tensor<3xf32> +// CHECK-NEXT: %4 = stablehlo.reduce(%3 init: %cst) applies stablehlo.add across dimensions = [0] : (tensor<3xf32>, tensor) -> tensor +// CHECK-NEXT: %5 = stablehlo.reshape %4 : (tensor) -> tensor<1x1x1xf32> +// CHECK-NEXT: %6 = stablehlo.pad %5, %cst, low = [0, 0, 0], high = [0, 3, 0], interior = [0, 0, 0] : (tensor<1x1x1xf32>, tensor) -> tensor<1x4x1xf32> +// CHECK-NEXT: return %6, %arg0 : tensor<1x4x1xf32>, tensor<1x4x1xf32> +// CHECK-NEXT: }