diff --git a/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp b/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp index e60f40d17b..cbf7d15e93 100644 --- a/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp +++ b/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp @@ -22704,36 +22704,107 @@ struct ScatterMultiplySimplify final } } - auto status = - detectConstantSetindexScatterOp(scatterOp, /*allowedMultipleUses*/ - false, - /*onlyConstantZerosAllowed*/ - true, nullptr); - if (!status.ok()) + if (!scatterOp.getUniqueIndices()) { + return failure(); + } + + bool isAllZeros = false, isAllOnes = false; + + SplatElementsAttr constSetIndexValue = nullptr; + auto status = detectConstantSetindexScatterOp( + scatterOp, /*allowedMultipleUses*/ false, + [&isAllZeros, &isAllOnes](mlir::Value input) { + isAllZeros = matchPattern(input, m_AnyZeroFloat()) || + matchPattern(input, m_Zero()); + if (!isAllZeros) { + isAllOnes = matchPattern(input, m_OneFloat()) || + matchPattern(input, m_One()); + } + return isAllZeros || isAllOnes; + }, + constSetIndexValue); + if (!status.ok()) { return rewriter.notifyMatchFailure(op, status.message()); + } SmallVector<int64_t> sliceSizes = computeGatherSliceSizes(scatterOp); - auto gatheredValues = stablehlo::GatherOp::create( - rewriter, op.getLoc(), otherValue, scatterOp.getScatterIndices(), - getGatherDims(rewriter.getContext(), - scatterOp.getScatterDimensionNumbers()), - rewriter.getDenseI64ArrayAttr(sliceSizes), - scatterOp.getIndicesAreSortedAttr()); + if (isAllZeros) { // non scattered values before zeros + auto gatheredValues = stablehlo::GatherOp::create( + rewriter, op.getLoc(), otherValue, scatterOp.getScatterIndices(), + getGatherDims(rewriter.getContext(), + scatterOp.getScatterDimensionNumbers()), + rewriter.getDenseI64ArrayAttr(sliceSizes), + scatterOp.getIndicesAreSortedAttr()); + + Value mulRhs; + if (constSetIndexValue) { + mulRhs = stablehlo::ConstantOp::create( + rewriter, op.getLoc(), + constSetIndexValue.resizeSplat( + cast<ShapedType>(gatheredValues.getType()))); + } else { + 
mulRhs = scatterOp.getUpdates()[0]; + } + auto newUpdates = + stablehlo::MulOp::create(rewriter, op.getLoc(), gatheredValues, mulRhs); - auto newUpdates = stablehlo::MulOp::create( - rewriter, op.getLoc(), gatheredValues, scatterOp.getUpdates()[0]); + auto newScatterOp = stablehlo::ScatterOp::create( + rewriter, op.getLoc(), scatterOp.getResultTypes(), + scatterOp.getInputs(), scatterOp.getScatterIndices(), + ValueRange(newUpdates), scatterOp.getScatterDimensionNumbersAttr(), + scatterOp.getIndicesAreSortedAttr(), + scatterOp.getUniqueIndicesAttr()); - auto newScatterOp = stablehlo::ScatterOp::create( - rewriter, op.getLoc(), scatterOp.getResultTypes(), - scatterOp.getInputs(), scatterOp.getScatterIndices(), - ValueRange(newUpdates), scatterOp.getScatterDimensionNumbersAttr(), - scatterOp.getIndicesAreSortedAttr(), scatterOp.getUniqueIndicesAttr()); - newScatterOp.getUpdateComputation().takeBody( - scatterOp.getUpdateComputation()); - rewriter.replaceOp(op, newScatterOp); + auto &updateRegion = newScatterOp.getUpdateComputation(); + auto *block = rewriter.createBlock(&updateRegion); + auto elemType = cast<ShapedType>(scatterOp.getResultTypes()[0]) + .getElementType(); + auto argType = RankedTensorType::get({}, elemType); + block->addArgument(argType, op.getLoc()); + block->addArgument(argType, op.getLoc()); + rewriter.setInsertionPointToStart(block); + stablehlo::ReturnOp::create(rewriter, op.getLoc(), block->getArgument(1)); - return success(); + rewriter.replaceOp(op, newScatterOp); + return success(); + } + + if (isAllOnes) { // non-scattered values stay as is + auto newScatterOp = stablehlo::ScatterOp::create( + rewriter, op.getLoc(), scatterOp.getResultTypes(), + ValueRange(otherValue), scatterOp.getScatterIndices(), + scatterOp.getUpdates(), scatterOp.getScatterDimensionNumbersAttr(), + scatterOp.getIndicesAreSortedAttr(), + scatterOp.getUniqueIndicesAttr()); + + auto &updateRegion = newScatterOp.getUpdateComputation(); + auto *block = rewriter.createBlock(&updateRegion); 
+ auto elemType = cast<ShapedType>(scatterOp.getResultTypes()[0]) + .getElementType(); + auto argType = RankedTensorType::get({}, elemType); + block->addArgument(argType, op.getLoc()); + block->addArgument(argType, op.getLoc()); + rewriter.setInsertionPointToStart(block); + + Value mulRhs; + if (constSetIndexValue) { + mulRhs = stablehlo::ConstantOp::create( + rewriter, op.getLoc(), + constSetIndexValue.resizeSplat( + RankedTensorType::get({}, elemType))); + } else { + mulRhs = block->getArgument(1); + } + auto mulOp = stablehlo::MulOp::create(rewriter, op.getLoc(), + block->getArgument(0), mulRhs); + stablehlo::ReturnOp::create(rewriter, op.getLoc(), mulOp.getResult()); + + rewriter.replaceOp(op, newScatterOp); + return success(); + } + + return failure(); } }; @@ -22807,10 +22878,9 @@ struct UnaryElementwiseScatterSimplify final if (!scatterOp) return rewriter.notifyMatchFailure(op, "not a scatter op"); - DenseElementsAttr scatterInputAttr; auto status = detectConstantSetindexScatterOp( - scatterOp, false, /*onlyConstantZerosAllowed*/ false, - &scatterInputAttr); + scatterOp, false, + [](mlir::Value input) { return matchPattern(input, m_Constant()); }); if (!status.ok()) { return rewriter.notifyMatchFailure(op, status.message()); } diff --git a/src/enzyme_ad/jax/Passes/StructuredTensors.cpp b/src/enzyme_ad/jax/Passes/StructuredTensors.cpp index 49cbd706f4..fa4402a8e4 100644 --- a/src/enzyme_ad/jax/Passes/StructuredTensors.cpp +++ b/src/enzyme_ad/jax/Passes/StructuredTensors.cpp @@ -12,8 +12,15 @@ namespace enzyme { absl::Status detectConstantSetindexScatterOp(stablehlo::ScatterOp scatterOp, bool allowedMultipleUses, - bool onlyConstantZerosAllowed, - DenseElementsAttr *constAttr) { + InputValidatorFn inputValidator) { + SplatElementsAttr constSetIndexValue = nullptr; + return detectConstantSetindexScatterOp(scatterOp, allowedMultipleUses, + inputValidator, constSetIndexValue); +} + +absl::Status detectConstantSetindexScatterOp( + stablehlo::ScatterOp scatterOp, bool 
allowedMultipleUses, + InputValidatorFn inputValidator, SplatElementsAttr &constSetIndexValue) { if (scatterOp.getInputs().size() != 1) { return absl::UnimplementedError( "Detection not implemented for scatter op with >1 input."); @@ -26,20 +33,18 @@ absl::Status detectConstantSetindexScatterOp(stablehlo::ScatterOp scatterOp, auto checkCommonScatterOp = mlir::stablehlo::CheckCommonScatterOp(scatterOp); - if (!checkCommonScatterOp.isSetindexScatter) { + if (!checkCommonScatterOp.isSetindexScatter && + !checkCommonScatterOp.isConstantSetindexScatter) { return absl::InvalidArgumentError("ScatterOp is not a setindex op."); } + if (checkCommonScatterOp.isConstantSetindexScatter) { + constSetIndexValue = checkCommonScatterOp.constant; + } + auto input = scatterOp.getInputs()[0]; - if (onlyConstantZerosAllowed) { - if (matchPattern(input, m_AnyZeroFloat()) || - matchPattern(input, m_Zero())) { - return absl::OkStatus(); - } - } else { - if (matchPattern(input, m_Constant(constAttr))) { - return absl::OkStatus(); - } + if (inputValidator(input)) { + return absl::OkStatus(); } return absl::InvalidArgumentError( @@ -49,7 +54,11 @@ absl::Status detectConstantSetindexScatterOp(stablehlo::ScatterOp scatterOp, // TODO: detect batched diagonal tensors absl::Status detectDiagonalTensor(stablehlo::ScatterOp scatterOp, mlir::Value *outUpdates) { - auto status = detectConstantSetindexScatterOp(scatterOp, true, true, nullptr); + auto status = + detectConstantSetindexScatterOp(scatterOp, true, [](mlir::Value input) { + return matchPattern(input, m_AnyZeroFloat()) || + matchPattern(input, m_Zero()); + }); if (!status.ok()) return status; diff --git a/src/enzyme_ad/jax/Passes/StructuredTensors.h b/src/enzyme_ad/jax/Passes/StructuredTensors.h index cdde530cf2..200bc08dc7 100644 --- a/src/enzyme_ad/jax/Passes/StructuredTensors.h +++ b/src/enzyme_ad/jax/Passes/StructuredTensors.h @@ -7,15 +7,21 @@ #include "llvm/ADT/SetVector.h" +#include #include namespace mlir { namespace enzyme { +using 
InputValidatorFn = std::function<bool(mlir::Value)>; + +absl::Status detectConstantSetindexScatterOp( + stablehlo::ScatterOp scatterOp, bool allowedMultipleUses, + InputValidatorFn inputValidator, SplatElementsAttr &constSetIndexValue); + absl::Status detectConstantSetindexScatterOp(stablehlo::ScatterOp scatterOp, bool allowedMultipleUses, - bool onlyConstantZerosAllowed, - DenseElementsAttr *constAttr); + InputValidatorFn inputValidator); absl::Status detectDiagonalTensor(stablehlo::ScatterOp scatterOp, mlir::Value *outUpdates); diff --git a/src/enzyme_ad/jax/Utils.cpp b/src/enzyme_ad/jax/Utils.cpp index 9d1387e8ed..9c756d3a47 100644 --- a/src/enzyme_ad/jax/Utils.cpp +++ b/src/enzyme_ad/jax/Utils.cpp @@ -1422,26 +1422,44 @@ getGatherDims(mlir::MLIRContext *ctx, scatterDimNumbers.getIndexVectorDim()); } -bool isSetindexBlock(mlir::Block *block) { - if (block->getNumArguments() != 2) +bool isSetindexBlockHelper( + mlir::Block *block, + std::function<bool(stablehlo::ReturnOp, Value)> fn) { + if (block->getNumArguments() != 2) { return false; - - auto updateValue = block->getArgument(1); + } // The block should have exactly one operation (the return) - if (block->getOperations().size() != 1) + if (block->getOperations().size() != 1) { return false; + } auto &returnOp = block->front(); auto stablehloReturnOp = dyn_cast<stablehlo::ReturnOp>(returnOp); - if (!stablehloReturnOp) + if (!stablehloReturnOp) { return false; + } - if (stablehloReturnOp.getNumOperands() != 1) + if (stablehloReturnOp.getNumOperands() != 1) { return false; + } + + return fn(stablehloReturnOp, block->getArgument(1)); +} - // The returned value should be the update value (second argument) - return stablehloReturnOp.getOperand(0) == updateValue; +bool isSetindexBlock(mlir::Block *block) { + return isSetindexBlockHelper( + block, [](stablehlo::ReturnOp retOp, Value updateValue) { + return retOp.getOperand(0) == updateValue; + }); +} + +bool isConstantSetindexBlock(mlir::Block *block, + mlir::SplatElementsAttr &constant) { + return isSetindexBlockHelper( + block, 
[&constant](stablehlo::ReturnOp retOp, Value updateValue) { + return matchPattern(retOp.getOperand(0), m_Constant(&constant)); + }); } SmallVector<int64_t> computeGatherSliceSizes(stablehlo::ScatterOp &scatterOp) { diff --git a/src/enzyme_ad/jax/Utils.h b/src/enzyme_ad/jax/Utils.h index 48c14c16e9..4493b61788 100644 --- a/src/enzyme_ad/jax/Utils.h +++ b/src/enzyme_ad/jax/Utils.h @@ -942,6 +942,8 @@ getGatherDims(mlir::MLIRContext *ctx, stablehlo::ScatterDimensionNumbersAttr scatterDimNumbers); bool isSetindexBlock(mlir::Block *block); +bool isConstantSetindexBlock(mlir::Block *block, + mlir::SplatElementsAttr &constant); // rhs is only considered if commutative is false template @@ -1067,6 +1069,8 @@ struct CheckCommonReduceOp { struct CheckCommonScatterOp { public: bool isSetindexScatter; + bool isConstantSetindexScatter; + bool isAddScatter; bool isMinScatter; bool isMaxScatter; @@ -1087,6 +1091,7 @@ struct CheckCommonScatterOp { if (!updateComputation.hasOneBlock()) { isSetindexScatter = false; + isConstantSetindexScatter = false; isAddScatter = false; isMinScatter = false; isMaxScatter = false; @@ -1105,6 +1110,7 @@ struct CheckCommonScatterOp { auto &block = updateComputation.front(); isSetindexScatter = isSetindexBlock(&block); + isConstantSetindexScatter = isConstantSetindexBlock(&block, constant); isAddScatter = isOnlyOpBlock<stablehlo::AddOp>(&block); isMulScatter = isOnlyOpBlock<stablehlo::MulOp>(&block); isMinScatter = isOnlyOpBlock<stablehlo::MinOp>(&block); diff --git a/test/lit_tests/mulscatter.mlir b/test/lit_tests/mulscatter.mlir index 407d89b299..af70b8e479 100644 --- a/test/lit_tests/mulscatter.mlir +++ b/test/lit_tests/mulscatter.mlir @@ -11,7 +11,7 @@ func.func @main(%arg0: tensor<4xi64>, %arg1: tensor<6xi64>, %arg2: tensor<1024x1 %4 = stablehlo.reshape %2 : (tensor<6x4xi64>) -> tensor<24x1xi64> %5 = stablehlo.concatenate %3, %4, dim = 1 : (tensor<24x1xi64>, tensor<24x1xi64>) -> tensor<24x2xi64> %6 = stablehlo.subtract %5, %c : tensor<24x2xi64> - %7 = "stablehlo.scatter"(%cst_0, %6, %cst) 
<{scatter_dimension_numbers = #stablehlo.scatter}> ({ + %7 = "stablehlo.scatter"(%cst_0, %6, %cst) <{scatter_dimension_numbers = #stablehlo.scatter, unique_indices = true}> ({ ^bb0(%arg3: tensor, %arg4: tensor): stablehlo.return %arg4 : tensor }) : (tensor<1024x1024xf32>, tensor<24x2xi64>, tensor<24xf32>) -> tensor<1024x1024xf32> @@ -27,7 +27,7 @@ func.func @main(%arg0: tensor<4xi64>, %arg1: tensor<6xi64>, %arg2: tensor<1024x1 // CHECK: %[[arg2_T:.*]] = stablehlo.transpose %arg2, dims = [1, 0] : (tensor<1024x1024xf32>) -> tensor<1024x1024xf32> // CHECK: %[[GATHER:.*]] = "stablehlo.gather"(%[[arg2_T]], %[[SCATTER_INDICES:.*]]) <{dimension_numbers = #stablehlo.gather, slice_sizes = array}> : (tensor<1024x1024xf32>, tensor<24x2xi64>) -> tensor<24xf32> // CHECK: %[[MUL:.*]] = stablehlo.multiply %[[GATHER]], %[[CST]] : tensor<24xf32> -// CHECK: %[[SCATTER:.*]] = "stablehlo.scatter"(%[[CST_1]], %[[SCATTER_INDICES]], %[[MUL]]) <{scatter_dimension_numbers = #stablehlo.scatter}> ({ +// CHECK: %[[SCATTER:.*]] = "stablehlo.scatter"(%[[CST_1]], %[[SCATTER_INDICES]], %[[MUL]]) <{scatter_dimension_numbers = #stablehlo.scatter, unique_indices = true}> ({ // CHECK: %[[RESULT:.*]] = stablehlo.transpose %[[SCATTER]], dims = [1, 0] : (tensor<1024x1024xf32>) -> tensor<1024x1024xf32> // CHECK: return %[[RESULT]] : tensor<1024x1024xf32> // CHECK: } diff --git a/test/lit_tests/mulscatterones.mlir b/test/lit_tests/mulscatterones.mlir new file mode 100644 index 0000000000..cdd6145b6a --- /dev/null +++ b/test/lit_tests/mulscatterones.mlir @@ -0,0 +1,35 @@ +// RUN: enzymexlamlir-opt %s --enzyme-hlo-opt | FileCheck %s + +func.func @main(%arg0: tensor<4xi64>, %arg1: tensor<6xi64>, %arg2: tensor<1024x1024xf32>) -> tensor<1024x1024xf32> { + %cst = stablehlo.constant dense<2.000000e+00> : tensor<24xf32> + %c = stablehlo.constant dense<1> : tensor<24x2xi64> + %cst_0 = stablehlo.constant dense<1.000000e+00> : tensor<1024x1024xf32> + %0 = stablehlo.transpose %arg2, dims = [1, 0] : 
(tensor<1024x1024xf32>) -> tensor<1024x1024xf32> + %1 = stablehlo.concatenate %arg0, %arg0, %arg0, %arg0, %arg0, %arg0, dim = 0 : (tensor<4xi64>, tensor<4xi64>, tensor<4xi64>, tensor<4xi64>, tensor<4xi64>, tensor<4xi64>) -> tensor<24xi64> + %2 = stablehlo.broadcast_in_dim %arg1, dims = [0] : (tensor<6xi64>) -> tensor<6x4xi64> + %3 = stablehlo.reshape %1 : (tensor<24xi64>) -> tensor<24x1xi64> + %4 = stablehlo.reshape %2 : (tensor<6x4xi64>) -> tensor<24x1xi64> + %5 = stablehlo.concatenate %3, %4, dim = 1 : (tensor<24x1xi64>, tensor<24x1xi64>) -> tensor<24x2xi64> + %6 = stablehlo.subtract %5, %c : tensor<24x2xi64> + %7 = "stablehlo.scatter"(%cst_0, %6, %cst) <{scatter_dimension_numbers = #stablehlo.scatter, unique_indices = true}> ({ + ^bb0(%arg3: tensor, %arg4: tensor): + stablehlo.return %arg4 : tensor + }) : (tensor<1024x1024xf32>, tensor<24x2xi64>, tensor<24xf32>) -> tensor<1024x1024xf32> + %8 = stablehlo.multiply %7, %0 : tensor<1024x1024xf32> + %9 = stablehlo.transpose %8, dims = [1, 0] : (tensor<1024x1024xf32>) -> tensor<1024x1024xf32> + return %9 : tensor<1024x1024xf32> +} + +// CHECK: func.func @main(%arg0: tensor<4xi64>, %arg1: tensor<6xi64>, %arg2: tensor<1024x1024xf32>) -> tensor<1024x1024xf32> { +// CHECK: %[[S_CST:.*]] = stablehlo.constant dense<2.000000e+00> : tensor +// CHECK: %[[CST:.*]] = stablehlo.constant dense<2.000000e+00> : tensor<24xf32> +// CHECK: %[[CST_0:.*]] = stablehlo.constant dense<1> : tensor<24x2xi64> +// CHECK: %[[arg2_T:.*]] = stablehlo.transpose %arg2, dims = [1, 0] : (tensor<1024x1024xf32>) -> tensor<1024x1024xf32> +// CHECK: %[[SCATTER:.*]] = "stablehlo.scatter"(%[[arg2_T]], %[[SCATTER_INDICES:.*]], %[[CST]]) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = true}> ({ +// CHECK: ^bb0(%[[ARG3:.*]]: tensor, %[[ARG4:.*]]: tensor): +// CHECK: %[[MUL:.*]] = stablehlo.multiply %[[ARG3]], %[[S_CST]] : tensor +// CHECK: stablehlo.return %[[MUL]] : tensor +// CHECK: }) +// CHECK: %[[RESULT:.*]] 
= stablehlo.transpose %[[SCATTER]], dims = [1, 0] : (tensor<1024x1024xf32>) -> tensor<1024x1024xf32> +// CHECK: return %[[RESULT]] : tensor<1024x1024xf32> +// CHECK: } diff --git a/test/lit_tests/unaryscatter.mlir b/test/lit_tests/unaryscatter.mlir index d607e04c16..753156db99 100644 --- a/test/lit_tests/unaryscatter.mlir +++ b/test/lit_tests/unaryscatter.mlir @@ -88,7 +88,7 @@ func.func @convertscatter(%arg0: tensor<5x4xf32>, %arg1: tensor<5xui32>) -> tens %5 = stablehlo.reshape %3 : (tensor<5xi64>) -> tensor<5x1xi64> %6 = stablehlo.concatenate %4, %5, dim = 1 : (tensor<5x1xi64>, tensor<5x1xi64>) -> tensor<5x2xi64> %7 = stablehlo.remainder %6, %c : tensor<5x2xi64> - %8 = "stablehlo.scatter"(%c_3, %7, %c_1) <{scatter_dimension_numbers = #stablehlo.scatter}> ({ + %8 = "stablehlo.scatter"(%c_3, %7, %c_1) <{scatter_dimension_numbers = #stablehlo.scatter, unique_indices = true}> ({ ^bb0(%arg2: tensor, %arg3: tensor): stablehlo.return %arg3 : tensor }) : (tensor<4x5xi1>, tensor<5x2xi64>, tensor<5xi1>) -> tensor<4x5xi1> @@ -111,8 +111,8 @@ func.func @convertscatter(%arg0: tensor<5x4xf32>, %arg1: tensor<5xui32>) -> tens // CHECK-NEXT: %5 = stablehlo.reshape %3 : (tensor<5xi64>) -> tensor<5x1xi64> // CHECK-NEXT: %6 = stablehlo.concatenate %4, %5, dim = 1 : (tensor<5x1xi64>, tensor<5x1xi64>) -> tensor<5x2xi64> // CHECK-NEXT: %7 = stablehlo.remainder %6, %c : tensor<5x2xi64> -// CHECK-NEXT: %8 = "stablehlo.gather"(%0, %7) <{dimension_numbers = #stablehlo.gather, slice_sizes = array}> : (tensor<4x5xf32>, tensor<5x2xi64>) -> tensor<5xf32> -// CHECK-NEXT: %9 = "stablehlo.scatter"(%cst, %7, %8) <{scatter_dimension_numbers = #stablehlo.scatter}> ({ +// CHECK-NEXT: %8 = "stablehlo.gather"(%0, %7) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<4x5xf32>, tensor<5x2xi64>) -> tensor<5xf32> +// CHECK-NEXT: %9 = "stablehlo.scatter"(%cst, %7, %8) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, 
unique_indices = true}> ({ // CHECK-NEXT: ^bb0(%arg2: tensor, %arg3: tensor): // CHECK-NEXT: stablehlo.return %arg3 : tensor // CHECK-NEXT: }) : (tensor<4x5xf32>, tensor<5x2xi64>, tensor<5xf32>) -> tensor<4x5xf32>