Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
157 changes: 155 additions & 2 deletions src/enzyme_ad/jax/Implementations/HLODerivatives.td
Original file line number Diff line number Diff line change
Expand Up @@ -977,8 +977,6 @@ def : HLODerivative<"Atan2Op", (Op $x, $y),
(CheckedDiv (Sub (Mul $x, (Shadow $y)), (Mul $y, (Shadow $x))), (Add (Pow $x, (HLOConstantFP<"2">)), (Pow $y, (HLOConstantFP<"2"> $y))))
>;

def : HLOReadOnlyIdentityOp<"BroadcastInDimOp">;

// Gather Adjoint
def ResultDimensionNumbers : GlobalExpr</*needsprimal*/0, /*needsshadow*/0, [{
op.getDimensionNumbers();
Expand Down Expand Up @@ -1446,3 +1444,158 @@ def : HLODerivative<"FftOp",
Fft (Shadow $x), (FftType), (FftLength)
)
>;

// Returns an I64 attribute holding this broadcast's dimension indices
// (output dim targeted by each input dim), adjusted for batched reverse
// mode. When gutils->width > 1 a leading batch dimension is prepended to
// BOTH the shadow operand and the shadow result, so every original mapping
// shifts right by one and input batch dim 0 maps to output dim 0.
def getBroadcastDimensionsWithBatch : GlobalExpr</*needsprimal*/0, /*needsshadow*/0, [{
auto bcastDims = op.getBroadcastDimensions();
SmallVector<int64_t> newBcastDims;
for (auto dim : bcastDims) {
newBcastDims.push_back(to_i64(dim));
}
if (gutils->width > 1) {
// Shift each original output-dimension index past the new batch dim and
// map the leading input batch dim to output dim 0.
// (Previously this inserted gutils->width -- the batch *size* -- as a
// dimension *index* without shifting, which yields out-of-range or
// duplicate broadcast dimensions for batched gradients.)
for (auto &dim : newBcastDims) {
dim += 1;
}
newBcastDims.insert(newBcastDims.begin(), 0);
}
getI64Attr(builder, newBcastDims);
}]>;

// Computes the output dimensions of the BroadcastInDimOp that must be summed
// over when pulling the gradient back to the operand: every output dimension
// that is not the image of an input dimension, plus every mapped dimension
// whose input extent is 1 (a degenerate broadcast).
// Under batched reverse mode (gutils->width > 1) the incoming gradient
// carries an extra leading batch dimension, so each index is shifted by one.
def BroadcastDimsToReductionDims : GlobalExpr</*needsprimal*/0, /*needsshadow*/0, [{
SmallVector<int64_t> reduceDims;
auto outTy = cast<RankedTensorType>(op.getType());
auto bcastDims = op.getBroadcastDimensions();
auto inTy = cast<RankedTensorType>(op.getOperand().getType());

// For each output dim, find the input dim (if any) that broadcasts into it.
for (auto en : llvm::enumerate(outTy.getShape())) {
ssize_t bcastIdx = -1;
for (auto en2 : llvm::enumerate(bcastDims)) {
if (en2.value() == en.index()) {
bcastIdx = en2.index();
break;
}
}
if (bcastIdx != -1) {
// Mapped dim: reduce only if a size-1 input extent was broadcast up.
if (en.value() != inTy.getShape()[bcastIdx]) {
reduceDims.push_back(en.index());
assert(inTy.getShape()[bcastIdx] == 1);
}
continue;
}
// Unmapped output dim: created by the broadcast, always reduced.
reduceDims.push_back(en.index());
}

// Shift all indices past the leading batch dimension of batched gradients.
if (gutils->width > 1) {
for (int i = 0; i < reduceDims.size(); i++) {
reduceDims[i]++;
}
}
getI64Attr(builder, reduceDims);
}]>;

// Creates the rank-0 additive identity (zero) used as the init value of the
// gradient-summing ReduceOp, with the element type of the op's first operand.
def ReduceAddNullValue : GlobalExpr</*needsprimal*/0, /*needsshadow*/0, [{
auto operand = op->getOperand(0);
auto operandElemType = cast<RankedTensorType>(operand.getType()).getElementType();
auto bodyTy = RankedTensorType::get({}, operandElemType);
cast<AutoDiffTypeInterface>(
gutils->getShadowType(bodyTy)).createNullValue(builder, op.getLoc());
}]>;

// Builds the permutation that undoes the broadcast's implicit dimension
// reordering: `mapping` sends each output dim to its position in a layout
// where dims fed by input dims come first (in broadcast_dimensions order),
// followed by the broadcast-created dims; `perm` is the inverse of `mapping`,
// fed to the TransposeOp of the reverse rule.
// NOTE(review): `perm` has outTy.getRank() entries and is not adjusted for
// the extra leading batch dimension when gutils->width > 1 -- confirm with an
// explicit batched test that this composes correctly with batching.
def BroadcastDimensionsToInversePermutation : GlobalExpr</*needsprimal*/0, /*needsshadow*/0, [{
auto outTy = cast<RankedTensorType>(op.getType());
SmallVector<int64_t> perm(outTy.getRank());
SmallVector<int64_t> mapping(outTy.getRank(), -1);
for (auto [i, dim] : llvm::enumerate(op.getBroadcastDimensions())) {
mapping[to_i64(dim)] = i;
}

// Output dims not produced from an input dim are appended after the mapped
// ones, in increasing order.
int next = op.getBroadcastDimensions().size();
for (int i = 0; i < outTy.getRank(); i++) {
if (mapping[i] == -1) {
mapping[i] = next++;
}
}

// Invert the mapping to obtain the transpose permutation.
for (int i = 0; i < outTy.getRank(); i++) {
perm[mapping[i]] = i;
}
getI64Attr(builder, perm);
}]>;

// Computes the tensor type obtained by reshaping the reduced gradient so the
// summed-away dimensions reappear with extent 1, ready for the transpose and
// the final reshape back to the operand type. The reduce-dim computation
// mirrors BroadcastDimsToReductionDims above; keep the two in sync.
// Under batching a leading batch dimension of size gutils->width is
// prepended to the resulting shape.
def InsertDeletedReduceDimsType : GlobalExpr</*needsprimal*/0, /*needsshadow*/0, [{
SmallVector<int64_t> reduceDims;
auto outTy = cast<RankedTensorType>(op.getType());
auto bcastDims = op.getBroadcastDimensions();
auto inTy = cast<RankedTensorType>(op.getOperand().getType());
auto outShape = outTy.getShape();

// For each output dim, find the input dim (if any) that broadcasts into it.
for (auto en : llvm::enumerate(outTy.getShape())) {
ssize_t bcastIdx = -1;
for (auto en2 : llvm::enumerate(bcastDims)) {
if (en2.value() == en.index()) {
bcastIdx = en2.index();
break;
}
}
if (bcastIdx != -1) {
// Mapped dim: reduced only if a size-1 input extent was broadcast up.
if (en.value() != inTy.getShape()[bcastIdx]) {
reduceDims.push_back(en.index());
assert(inTy.getShape()[bcastIdx] == 1);
}
continue;
}
// Unmapped output dim: created by the broadcast, always reduced.
reduceDims.push_back(en.index());
}

// Reduced dims come back with extent 1; all others keep the output extent.
SmallVector<int64_t> reshapeShape(outTy.getRank(), -1);
for (auto [i, sz] : llvm::enumerate(outShape)) {
if (llvm::is_contained(reduceDims, i)) {
reshapeShape[i] = 1;
} else {
reshapeShape[i] = sz;
}
}

// Prepend the batch dimension for batched gradients.
if (gutils->width > 1) {
reshapeShape.insert(reshapeShape.begin(), gutils->width);
}
RankedTensorType::get(reshapeShape, outTy.getElementType());
}]>;

// A stablehlo.ReduceOp whose body is an add region (built by createAddRegion);
// the instruction's value is the reduction's first (and only) result.
def ReduceAdd : HLOInst<"ReduceOp", ")->getResult(0)", "createAddRegion(">;

// The op's result type, with a leading batch dimension of size gutils->width
// prepended when reverse-mode batching is active. Used as the shadow
// BroadcastInDim's result type.
def ResultTypeWithBatch : GlobalExpr</*needsprimal*/0, /*needsshadow*/0, [{
auto outTy = cast<RankedTensorType>(op.getType());
auto outShape = outTy.getShape();
SmallVector<int64_t> newShape(outShape.begin(), outShape.end());
if (gutils->width > 1) {
newShape.insert(newShape.begin(), gutils->width);
}
RankedTensorType::get(newShape, outTy.getElementType());
}]>;

// The op's operand type, with a leading batch dimension of size gutils->width
// prepended when reverse-mode batching is active. Used as the final reshape
// target in the reverse rule. Parallels ResultTypeWithBatch.
def InputTypeWithBatch : GlobalExpr</*needsprimal*/0, /*needsshadow*/0, [{
auto inTy = cast<RankedTensorType>(op.getOperand().getType());
auto inShape = inTy.getShape();
SmallVector<int64_t> newShape(inShape.begin(), inShape.end());
if (gutils->width > 1) {
newShape.insert(newShape.begin(), gutils->width);
}
RankedTensorType::get(newShape, inTy.getElementType());
}]>;

def : HLODerivative<"BroadcastInDimOp", (Op $x),
[
(
Reshape
(InputTypeWithBatch),
(
Transpose
(
Reshape
(InsertDeletedReduceDimsType),
(ReduceAdd (DiffeRet), (ReduceAddNullValue), (BroadcastDimsToReductionDims))
),
(BroadcastDimensionsToInversePermutation)
)
)
],
(
BroadcastInDim (ResultTypeWithBatch), (Shadow $x), (getBroadcastDimensionsWithBatch)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you explicitly add batch tests? In principle the tablegen interface should automatically do batching on top of this, so we should check that it doesn't accidentally conflict.

)>;
111 changes: 0 additions & 111 deletions src/enzyme_ad/jax/Implementations/StableHLOAutoDiffOpInterfaceImpl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1133,116 +1133,6 @@ class AutoDiffWhileRev
}
};

// Reverse-mode AD external model for stablehlo::BroadcastInDimOp.
// The adjoint of a broadcast is a sum over the broadcast-created (and
// degenerate size-1) output dimensions, followed by a permutation/reshape
// back to the operand's layout.
// NOTE(review): this C++ implementation is superseded by the tablegen
// HLODerivative for BroadcastInDimOp in this change; logic here must match
// BroadcastDimsToReductionDims / BroadcastDimensionsToInversePermutation.
class AutoDiffBroadcastInDimRev
: public ReverseAutoDiffOpInterface::ExternalModel<
AutoDiffBroadcastInDimRev, BroadcastInDimOp> {
public:
// Emits the adjoint: grad -> reduce(add) -> reshape -> transpose -> reshape,
// then accumulates into the operand's differential.
LogicalResult createReverseModeAdjoint(Operation *orig, OpBuilder &builder,
MGradientUtilsReverse *gutils,
SmallVector<Value> caches) const {
auto op = cast<BroadcastInDimOp>(orig);
auto inTy = op.getOperand().getType();
auto outTy = op.getType();
// Take the incoming gradient of the result and clear it (it is consumed).
auto inDiffe = gutils->diffe(op, builder);
gutils->zeroDiffe(op, builder);

SmallVector<int64_t> bcastDims(op.getBroadcastDimensions().begin(),
op.getBroadcastDimensions().end());

// reducedDims: output dims summed away (broadcast-created, or mapped from a
// size-1 input extent). iterShape: the surviving (kept) extents, i.e. the
// shape of the reduction result.
SmallVector<int64_t> reducedDims;
SmallVector<int64_t> iterShape;
for (auto en : llvm::enumerate(outTy.getShape())) {
// Find the input dim (if any) that broadcasts into this output dim.
ssize_t bcastIdx = -1;
for (auto en2 : llvm::enumerate(bcastDims)) {
if (en2.value() == en.index()) {
bcastIdx = en2.index();
break;
}
}
if (bcastIdx != -1) {
if (en.value() != inTy.getShape()[bcastIdx]) {
// Degenerate broadcast from extent 1: must be summed.
reducedDims.push_back(en.index());
assert(inTy.getShape()[bcastIdx] == 1);
} else {
iterShape.push_back(inTy.getShape()[bcastIdx]);
}
continue;
}
// Output dim created by the broadcast: always summed.
reducedDims.push_back(en.index());
}

// Shape with reduced dims restored as extent 1 (for the first reshape).
SmallVector<int64_t> reshapedShape(outTy.getRank(), -1);
for (auto [i, sz] : llvm::enumerate(outTy.getShape())) {
if (llvm::is_contained(reducedDims, i)) {
reshapedShape[i] = 1;
} else {
reshapedShape[i] = sz;
}
}

// mapping: output dim -> position with input-mapped dims first (in
// broadcast_dimensions order) then broadcast-created dims; perm inverts it.
SmallVector<int64_t> perm(outTy.getRank(), -1);
SmallVector<int64_t> mapping(outTy.getRank(), -1);
for (auto [i, dim] : llvm::enumerate(bcastDims)) {
mapping[dim] = i;
}

int next = bcastDims.size();
for (int i = 0; i < outTy.getRank(); i++) {
if (mapping[i] == -1) {
mapping[i] = next++;
}
}

for (int i = 0; i < outTy.getRank(); i++) {
perm[mapping[i]] = i;
}

auto reduceTy = RankedTensorType::get(iterShape, inTy.getElementType());
auto bodyTy = RankedTensorType::get({}, inTy.getElementType());

// Scalar additive identity for the reduction.
Value zero = cast<AutoDiffTypeInterface>(gutils->getShadowType(bodyTy))
.createNullValue(builder, op.getLoc());

// Sum the gradient over the reduced dims.
auto red = builder.create<ReduceOp>(
op.getLoc(), TypeRange(gutils->getShadowType(reduceTy)), inDiffe, zero,
reducedDims);
// Build the ReduceOp's add-region body: (a, b) -> return a + b.
red.getBody().push_back(new Block());
Block &body = red.getBody().front();
OpBuilder bodyBuilder(orig->getContext());
bodyBuilder.setInsertionPointToEnd(&body);

body.addArgument(bodyTy, op.getLoc());
body.addArgument(bodyTy, op.getLoc());
auto add = bodyBuilder.create<AddOp>(op.getLoc(), body.getArgument(0),
body.getArgument(1));
bodyBuilder.create<ReturnOp>(op.getLoc(), ValueRange(add));

// for simplicity we do grad -> reduce -> reshape (restore 1 dims) ->
// transpose -> reshape
// The repeated reshapes are then eliminated via `enzyme-hlo-opt`.
auto reshapedRed = builder.create<ReshapeOp>(
op.getLoc(),
RankedTensorType::get(reshapedShape, inTy.getElementType()),
red->getResult(0));
auto transposedVal =
builder.create<TransposeOp>(op.getLoc(), reshapedRed, perm);
auto res = builder.create<ReshapeOp>(
op.getLoc(), gutils->getShadowType(op.getOperand().getType()),
transposedVal);

// Accumulate into the operand's differential.
gutils->addToDiffe(op.getOperand(), res, builder);
return success();
}

// No values need caching between forward and reverse passes.
SmallVector<Value> cacheValues(Operation *orig,
MGradientUtilsReverse *gutils) const {
return {};
}

// No shadow values are created for this op.
void createShadowValues(Operation *op, OpBuilder &builder,
MGradientUtilsReverse *gutils) const {}
};

class AutoDiffSliceRev
: public ReverseAutoDiffOpInterface::ExternalModel<AutoDiffSliceRev,
SliceOp> {
Expand Down Expand Up @@ -3635,7 +3525,6 @@ void mlir::enzyme::registerStableHLODialectAutoDiffInterface(
WhileOp::attachInterface<AutoDiffWhileRev>(*context);
ReduceOp::attachInterface<AutoDiffReduceCF<ReduceOp>>(*context);
WhileOp::attachInterface<AutoDiffReduceCF<WhileOp>>(*context);
BroadcastInDimOp::attachInterface<AutoDiffBroadcastInDimRev>(*context);
SliceOp::attachInterface<AutoDiffSliceRev>(*context);
ReduceOp::attachInterface<AutoDiffReduceRev>(*context);
ReduceWindowOp::attachInterface<AutoDiffReduceWindowRev>(*context);
Expand Down
32 changes: 32 additions & 0 deletions test/lit_tests/diffrules/stablehlo/broadcastindim3.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
// RUN: enzymexlamlir-opt --enzyme --canonicalize --remove-unnecessary-enzyme-ops --arith-raise --inline --enzyme-hlo-opt %s | FileCheck %s

// Forward function: slice+reshape a scalar out of %arg0, broadcast it to 3
// elements with broadcast_in_dim (the op whose reverse derivative this test
// exercises), square elementwise, and sum-reduce to a scalar.
module {
func.func private @"Const{typeof(slicing)}(Main.slicing)_autodiff"(%arg0: tensor<1x4x1xf32>) -> (tensor<f32>, tensor<1x4x1xf32>) {
%cst = stablehlo.constant dense<0.000000e+00> : tensor<f32>
%cst_0 = stablehlo.constant dense<1.000000e+00> : tensor<3xf32>
%0 = stablehlo.slice %arg0 [0:1, 0:1, 0:1] : (tensor<1x4x1xf32>) -> tensor<1x1x1xf32>
%1 = stablehlo.reshape %0 : (tensor<1x1x1xf32>) -> tensor<1xf32>
%2 = stablehlo.broadcast_in_dim %1, dims = [0] : (tensor<1xf32>) -> tensor<3xf32>
%3 = stablehlo.multiply %2, %cst_0 : tensor<3xf32>
%4 = stablehlo.multiply %3, %3 : tensor<3xf32>
%5 = stablehlo.reduce(%4 init: %cst) applies stablehlo.add across dimensions = [0] : (tensor<3xf32>, tensor<f32>) -> tensor<f32>
return %5, %arg0 : tensor<f32>, tensor<1x4x1xf32>
}
// Driver: reverse-mode autodiff of the function above w.r.t. %arg0.
func.func @main(%arg0: tensor<1x4x1xf32>) -> (tensor<1x4x1xf32>, tensor<1x4x1xf32>) {
%cst = stablehlo.constant dense<1.000000e+00> : tensor<f32>
%0:2 = enzyme.autodiff @"Const{typeof(slicing)}(Main.slicing)_autodiff"(%arg0, %cst) {activity = [#enzyme<activity enzyme_active>], ret_activity = [#enzyme<activity enzyme_activenoneed>, #enzyme<activity enzyme_const>]} : (tensor<1x4x1xf32>, tensor<f32>) -> (tensor<1x4x1xf32>, tensor<1x4x1xf32>)
return %0#1, %0#0 : tensor<1x4x1xf32>, tensor<1x4x1xf32>
}
}

// CHECK: func.func @main(%arg0: tensor<1x4x1xf32>) -> (tensor<1x4x1xf32>, tensor<1x4x1xf32>) {
// CHECK-NEXT: %cst = stablehlo.constant dense<0.000000e+00> : tensor<f32>
// CHECK-NEXT: %0 = stablehlo.slice %arg0 [0:1, 0:1, 0:1] : (tensor<1x4x1xf32>) -> tensor<1x1x1xf32>
// CHECK-NEXT: %1 = stablehlo.reshape %0 : (tensor<1x1x1xf32>) -> tensor<1xf32>
// CHECK-NEXT: %2 = stablehlo.broadcast_in_dim %1, dims = [0] : (tensor<1xf32>) -> tensor<3xf32>
// CHECK-NEXT: %3 = stablehlo.add %2, %2 : tensor<3xf32>
// CHECK-NEXT: %4 = stablehlo.reduce(%3 init: %cst) applies stablehlo.add across dimensions = [0] : (tensor<3xf32>, tensor<f32>) -> tensor<f32>
// CHECK-NEXT: %5 = stablehlo.reshape %4 : (tensor<f32>) -> tensor<1x1x1xf32>
// CHECK-NEXT: %6 = stablehlo.pad %5, %cst, low = [0, 0, 0], high = [0, 3, 0], interior = [0, 0, 0] : (tensor<1x1x1xf32>, tensor<f32>) -> tensor<1x4x1xf32>
// CHECK-NEXT: return %6, %arg0 : tensor<1x4x1xf32>, tensor<1x4x1xf32>
// CHECK-NEXT: }
Loading