diff --git a/src/enzyme_ad/jax/Passes/LowerEnzymeProbProg.cpp b/src/enzyme_ad/jax/Passes/LowerEnzymeProbProg.cpp index 9e0a47e06..8a5c0e767 100644 --- a/src/enzyme_ad/jax/Passes/LowerEnzymeProbProg.cpp +++ b/src/enzyme_ad/jax/Passes/LowerEnzymeProbProg.cpp @@ -1590,14 +1590,18 @@ struct DotOpConversion : public OpConversionPattern { auto lhs = adaptor.getLhs(); auto rhs = adaptor.getRhs(); auto resultType = cast(op.getResult().getType()); - auto lhsType = cast(lhs.getType()); + + auto lhsBatching = op.getLhsBatchingDimensions(); + auto rhsBatching = op.getRhsBatchingDimensions(); + auto lhsContracting = op.getLhsContractingDimensions(); + auto rhsContracting = op.getRhsContractingDimensions(); auto dotDimensionNumbers = stablehlo::DotDimensionNumbersAttr::get( rewriter.getContext(), - /*lhs_batching_dimensions=*/{}, - /*rhs_batching_dimensions=*/{}, - /*lhs_contracting_dimensions=*/{0}, - /*rhs_contracting_dimensions=*/{0}); + SmallVector(lhsBatching.begin(), lhsBatching.end()), + SmallVector(rhsBatching.begin(), rhsBatching.end()), + SmallVector(lhsContracting.begin(), lhsContracting.end()), + SmallVector(rhsContracting.begin(), rhsContracting.end())); auto dotOp = stablehlo::DotGeneralOp::create( rewriter, op.getLoc(), resultType, lhs, rhs, dotDimensionNumbers, @@ -1609,6 +1613,51 @@ struct DotOpConversion : public OpConversionPattern { } }; +// Reference: +// https://github.com/jax-ml/jax/blob/e9b487238f0cfe932200bae842d26826f19ba2bc/jax/_src/lax/other.py#L262 +struct LogAddExpOpConversion : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + std::string backend; + LogAddExpOpConversion(std::string backend, TypeConverter &typeConverter, + MLIRContext *context, PatternBenefit benefit = 1) + : OpConversionPattern(typeConverter, context, benefit), backend(backend) { + } + + LogicalResult + matchAndRewrite(enzyme::LogAddExpOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto lhs = adaptor.getLhs(); + auto rhs = adaptor.getRhs(); + auto resultType = cast(op.getResult().getType()); + + auto amax = + stablehlo::MaxOp::create(rewriter, op.getLoc(), resultType, lhs, rhs); + auto delta = stablehlo::SubtractOp::create(rewriter, op.getLoc(), + resultType, lhs, rhs); + auto isNaN = + stablehlo::CompareOp::create(rewriter, op.getLoc(), delta, delta, + stablehlo::ComparisonDirection::NE); + auto nanResult = + stablehlo::AddOp::create(rewriter, op.getLoc(), resultType, lhs, rhs); + auto absDelta = + stablehlo::AbsOp::create(rewriter, op.getLoc(), resultType, delta); + auto negAbsDelta = + stablehlo::NegOp::create(rewriter, op.getLoc(), resultType, absDelta); + auto expNegAbsDelta = stablehlo::ExpOp::create(rewriter, op.getLoc(), + resultType, negAbsDelta); + auto log1pResult = stablehlo::Log1pOp::create(rewriter, op.getLoc(), + resultType, expNegAbsDelta); + auto normalResult = stablehlo::AddOp::create(rewriter, op.getLoc(), + resultType, amax, log1pResult); + auto result = stablehlo::SelectOp::create(rewriter, op.getLoc(), resultType, + isNaN, nanResult, normalResult); + + rewriter.replaceOp(op, result); + return success(); + } +}; + struct RandomOpConversion : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; @@ -2239,17 +2288,17 @@ struct GetFlattenedSamplesFromTraceOpConversion } }; -struct LoopOpConversion : public OpConversionPattern { +struct ForLoopOpConversion : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; std::string backend; - LoopOpConversion(std::string backend, TypeConverter &typeConverter, - MLIRContext *context, PatternBenefit benefit = 1) + ForLoopOpConversion(std::string backend, TypeConverter &typeConverter, + MLIRContext *context, PatternBenefit benefit = 1) : OpConversionPattern(typeConverter, context, benefit), backend(backend) { } LogicalResult - matchAndRewrite(enzyme::LoopOp op, OpAdaptor adaptor, + matchAndRewrite(enzyme::ForLoopOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { SmallVector initVals = {adaptor.getLowerBound()}; initVals.append(adaptor.getInitArgs().begin(), adaptor.getInitArgs().end()); @@ -2394,16 +2443,17 @@ struct LowerProbProgToStableHLOPass target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); + target.addIllegalOp(); target.addIllegalOp(); - target.addIllegalOp(); + target.addIllegalOp(); target.addLegalOp(); RewritePatternSet patterns(context); patterns.add( - backend, typeConverter, context); + LogAddExpOpConversion, UnflattenSliceOpConversion, + ForLoopOpConversion>(backend, typeConverter, context); if (failed(applyPartialConversion(getOperation(), target, std::move(patterns)))) { diff --git a/test/lit_tests/probprog/hmc.mlir b/test/lit_tests/probprog/hmc.mlir index f7c0deda9..fb121c94e 100644 --- a/test/lit_tests/probprog/hmc.mlir +++ b/test/lit_tests/probprog/hmc.mlir @@ -1,4 +1,4 @@ -// RUN: enzymexlamlir-opt %s --pass-pipeline="builtin.module(arith-raise,lower-probprog-to-stablehlo{backend=cpu},canonicalize,outline-enzyme-regions,enzyme,canonicalize,remove-unnecessary-enzyme-ops,canonicalize,enzyme-simplify-math,cse,inline,cse)" | FileCheck %s --check-prefix=CPU --dump-input=always +// RUN: enzymexlamlir-opt %s --pass-pipeline="builtin.module(arith-raise,lower-probprog-to-stablehlo{backend=cpu},canonicalize,outline-enzyme-regions,enzyme,canonicalize,inline,remove-unnecessary-enzyme-ops,canonicalize,enzyme-simplify-math,cse)" | FileCheck %s --check-prefix=CPU module { func.func private @normal(%arg0: tensor<2xui64>, %arg1: tensor, %arg2: tensor) -> (tensor<2xui64>, tensor) { @@ -15,6 +15,7 @@ module { %1:2 = enzyme.sample @normal(%0#0, %0#1, %arg2) {logpdf = @logpdf, name = "t", symbol = #enzyme.symbol<2>} : (tensor<2xui64>, tensor, tensor) -> (tensor<2xui64>, tensor) return %1#0, %1#1 : tensor<2xui64>, tensor } + func.func @hmc(%arg0: tensor<2xui64>, %arg1: tensor, %arg2: tensor) -> (!enzyme.Trace, tensor, tensor<2xui64>) { %cst = arith.constant dense<5.000000e-02> : tensor %cst_0 = arith.constant dense<1> : tensor @@ -29,50 +30,53 @@ module { %1 = enzyme.getFlattenedSamplesFromTrace %0 {selection = [[#enzyme.symbol<1>], [#enzyme.symbol<2>]]} : tensor<2xf64> %2 = enzyme.getWeightFromTrace %0 : tensor %3 = arith.negf %2 : tensor - %output_rng_state, %result = enzyme.random %arg0, %cst_4, %cst_7 {rng_distribution = #enzyme} : (tensor<2xui64>, tensor, tensor<2x2xf64>) -> (tensor<2xui64>, tensor<2xf64>) - %4 = enzyme.cholesky_solve %cst_7, %result : (tensor<2x2xf64>, tensor<2xf64>) -> tensor<2xf64> - %5 = enzyme.dot %result, %4 : (tensor<2xf64>, tensor<2xf64>) -> tensor - %6 = arith.mulf %5, %cst_2 : tensor - %7 = arith.addf %3, %6 : tensor - %8:2 = enzyme.autodiff_region(%1, %cst_3) { + %4 = enzyme.cholesky_solve %cst_7, %cst_7 : (tensor<2x2xf64>, tensor<2x2xf64>) -> tensor<2x2xf64> + %output_rng_state, %result = enzyme.random %arg0, %cst_4, %cst_3 {rng_distribution = #enzyme} : (tensor<2xui64>, tensor, tensor) -> (tensor<2xui64>, tensor<2xf64>) + %5 = enzyme.dot %4, %result {lhs_batching_dimensions = array, lhs_contracting_dimensions = array, rhs_batching_dimensions = array, rhs_contracting_dimensions = array} : (tensor<2x2xf64>, tensor<2xf64>) -> tensor<2xf64> + %6 = enzyme.dot %cst_7, %5 {lhs_batching_dimensions = array, lhs_contracting_dimensions = array, rhs_batching_dimensions = array, rhs_contracting_dimensions = array} : (tensor<2x2xf64>, tensor<2xf64>) -> tensor<2xf64> + %7 = enzyme.dot %5, %6 {lhs_batching_dimensions = array, lhs_contracting_dimensions = array, rhs_batching_dimensions = array, rhs_contracting_dimensions = array} : (tensor<2xf64>, tensor<2xf64>) -> tensor + %8 = arith.mulf %7, %cst_2 : tensor + %9 = arith.addf %3, %8 : tensor + %10:2 = enzyme.autodiff_region(%1, %cst_3) { ^bb0(%arg3: tensor<2xf64>): - %23:3 = func.call @test.update(%0, %arg3, %output_rng_state, %arg1, %arg2) : (!enzyme.Trace, tensor<2xf64>, tensor<2xui64>, tensor, tensor) -> (!enzyme.Trace, tensor, tensor<2xui64>) - %24 = arith.negf %23#1 : tensor - enzyme.yield %24, %23#2 : tensor, tensor<2xui64> + %25:3 = func.call @test.update(%0, %arg3, %output_rng_state, %arg1, %arg2) : (!enzyme.Trace, tensor<2xf64>, tensor<2xui64>, tensor, tensor) -> (!enzyme.Trace, tensor, tensor<2xui64>) + %26 = arith.negf %25#1 : tensor + enzyme.yield %26, %25#2 : tensor, tensor<2xui64> } attributes {activity = [#enzyme], ret_activity = [#enzyme, #enzyme]} : (tensor<2xf64>, tensor) -> (tensor<2xui64>, tensor<2xf64>) - %9 = "enzyme.broadcast"(%cst_6) <{shape = array}> : (tensor) -> tensor<2xf64> - %10 = "enzyme.broadcast"(%cst) <{shape = array}> : (tensor) -> tensor<2xf64> - %11:4 = enzyme.loop(%cst_1 : tensor) to(%cst_5 : tensor) step(%cst_0 : tensor) iter_args(%1, %result, %8#1, %8#0 : tensor<2xf64>, tensor<2xf64>, tensor<2xf64>, tensor<2xui64>) -> tensor<2xf64>, tensor<2xf64>, tensor<2xf64>, tensor<2xui64> { + %11 = "enzyme.broadcast"(%cst_6) <{shape = array}> : (tensor) -> tensor<2xf64> + %12 = "enzyme.broadcast"(%cst) <{shape = array}> : (tensor) -> tensor<2xf64> + %13:4 = enzyme.for_loop(%cst_1 : tensor) to(%cst_5 : tensor) step(%cst_0 : tensor) iter_args(%1, %5, %10#1, %10#0 : tensor<2xf64>, tensor<2xf64>, tensor<2xf64>, tensor<2xui64>) -> tensor<2xf64>, tensor<2xf64>, tensor<2xf64>, tensor<2xui64> { ^bb0(%arg3: tensor, %arg4: tensor<2xf64>, %arg5: tensor<2xf64>, %arg6: tensor<2xf64>, %arg7: tensor<2xui64>): - %23 = arith.mulf %10, %arg6 : tensor<2xf64> - %24 = arith.subf %arg5, %23 : tensor<2xf64> - %25 = enzyme.cholesky_solve %cst_7, %24 : (tensor<2x2xf64>, tensor<2xf64>) -> tensor<2xf64> - %26 = arith.mulf %9, %25 : tensor<2xf64> - %27 = arith.addf %arg4, %26 : tensor<2xf64> - %28:2 = enzyme.autodiff_region(%27, %cst_3) { + %25 = arith.mulf %12, %arg6 : tensor<2xf64> + %26 = arith.subf %arg5, %25 : tensor<2xf64> + %27 = enzyme.dot %cst_7, %26 {lhs_batching_dimensions = array, lhs_contracting_dimensions = array, rhs_batching_dimensions = array, rhs_contracting_dimensions = array} : (tensor<2x2xf64>, tensor<2xf64>) -> tensor<2xf64> + %28 = arith.mulf %11, %27 : tensor<2xf64> + %29 = arith.addf %arg4, %28 : tensor<2xf64> + %30:2 = enzyme.autodiff_region(%29, %cst_3) { ^bb0(%arg8: tensor<2xf64>): - %31:3 = func.call @test.update(%0, %arg8, %arg7, %arg1, %arg2) : (!enzyme.Trace, tensor<2xf64>, tensor<2xui64>, tensor, tensor) -> (!enzyme.Trace, tensor, tensor<2xui64>) - %32 = arith.negf %31#1 : tensor - enzyme.yield %32, %31#2 : tensor, tensor<2xui64> + %33:3 = func.call @test.update(%0, %arg8, %arg7, %arg1, %arg2) : (!enzyme.Trace, tensor<2xf64>, tensor<2xui64>, tensor, tensor) -> (!enzyme.Trace, tensor, tensor<2xui64>) + %34 = arith.negf %33#1 : tensor + enzyme.yield %34, %33#2 : tensor, tensor<2xui64> } attributes {activity = [#enzyme], ret_activity = [#enzyme, #enzyme]} : (tensor<2xf64>, tensor) -> (tensor<2xui64>, tensor<2xf64>) - %29 = arith.mulf %10, %28#1 : tensor<2xf64> - %30 = arith.subf %24, %29 : tensor<2xf64> - enzyme.yield %27, %30, %28#1, %28#0 : tensor<2xf64>, tensor<2xf64>, tensor<2xf64>, tensor<2xui64> + %31 = arith.mulf %12, %30#1 : tensor<2xf64> + %32 = arith.subf %26, %31 : tensor<2xf64> + enzyme.yield %29, %32, %30#1, %30#0 : tensor<2xf64>, tensor<2xf64>, tensor<2xf64>, tensor<2xui64> } - %12:3 = call @test.update(%0, %11#0, %11#3, %arg1, %arg2) : (!enzyme.Trace, tensor<2xf64>, tensor<2xui64>, tensor, tensor) -> (!enzyme.Trace, tensor, tensor<2xui64>) - %13 = arith.negf %12#1 : tensor - %14 = enzyme.cholesky_solve %cst_7, %11#1 : (tensor<2x2xf64>, tensor<2xf64>) -> tensor<2xf64> - %15 = enzyme.dot %11#1, %14 : (tensor<2xf64>, tensor<2xf64>) -> tensor - %16 = arith.mulf %15, %cst_2 : tensor - %17 = arith.addf %13, %16 : tensor - %18 = arith.subf %7, %17 : tensor - %19 = math.exp %18 : tensor - %20 = arith.minimumf %19, %cst_3 : tensor - %output_rng_state_8, %result_9 = enzyme.random %12#2, %cst_4, %cst_3 {rng_distribution = #enzyme} : (tensor<2xui64>, tensor, tensor) -> (tensor<2xui64>, tensor) - %21 = arith.cmpf olt, %result_9, %20 : tensor - %22 = enzyme.selectTrace %21, %12#0, %0 : tensor - return %22, %21, %output_rng_state_8 : !enzyme.Trace, tensor, tensor<2xui64> + %14:3 = call @test.update(%0, %13#0, %13#3, %arg1, %arg2) : (!enzyme.Trace, tensor<2xf64>, tensor<2xui64>, tensor, tensor) -> (!enzyme.Trace, tensor, tensor<2xui64>) + %15 = arith.negf %14#1 : tensor + %16 = enzyme.dot %cst_7, %13#1 {lhs_batching_dimensions = array, lhs_contracting_dimensions = array, rhs_batching_dimensions = array, rhs_contracting_dimensions = array} : (tensor<2x2xf64>, tensor<2xf64>) -> tensor<2xf64> + %17 = enzyme.dot %13#1, %16 {lhs_batching_dimensions = array, lhs_contracting_dimensions = array, rhs_batching_dimensions = array, rhs_contracting_dimensions = array} : (tensor<2xf64>, tensor<2xf64>) -> tensor + %18 = arith.mulf %17, %cst_2 : tensor + %19 = arith.addf %15, %18 : tensor + %20 = arith.subf %9, %19 : tensor + %21 = math.exp %20 : tensor + %22 = arith.minimumf %21, %cst_3 : tensor + %output_rng_state_8, %result_9 = enzyme.random %14#2, %cst_4, %cst_3 {rng_distribution = #enzyme} : (tensor<2xui64>, tensor, tensor) -> (tensor<2xui64>, tensor) + %23 = arith.cmpf olt, %result_9, %22 : tensor + %24 = enzyme.selectTrace %23, %14#0, %0 : tensor + return %24, %23, %output_rng_state_8 : !enzyme.Trace, tensor, tensor<2xui64> } + func.func @test.update(%arg0: !enzyme.Trace, %arg1: tensor<2xf64>, %arg2: tensor<2xui64>, %arg3: tensor, %arg4: tensor) -> (!enzyme.Trace, tensor, tensor<2xui64>) { %cst = arith.constant dense<0.000000e+00> : tensor %0 = enzyme.initTrace : !enzyme.Trace @@ -91,20 +95,20 @@ module { } // CPU: func.func @hmc(%arg0: tensor<2xui64>, %arg1: tensor, %arg2: tensor) -> (!enzyme.Trace, tensor, tensor<2xui64>) { -// CPU-NEXT: %cst = arith.constant dense<0.000000e+00> : tensor +// CPU-NEXT: %cst = stablehlo.constant dense<0.000000e+00> : tensor +// CPU-NEXT: %cst_0 = arith.constant dense<0.000000e+00> : tensor // CPU-NEXT: %c = stablehlo.constant dense<4607182418800017408> : tensor -// CPU-NEXT: %c_0 = stablehlo.constant dense<12> : tensor -// CPU-NEXT: %cst_1 = stablehlo.constant dense<1.4142135623730951> : tensor<2xf64> -// CPU-NEXT: %cst_2 = stablehlo.constant dense<2.000000e+00> : tensor<2xf64> -// CPU-NEXT: %cst_3 = stablehlo.constant dense<1.000000e+00> : tensor<2xf64> -// CPU-NEXT: %c_4 = stablehlo.constant dense<4607182418800017408> : tensor<2xui64> -// CPU-NEXT: %c_5 = stablehlo.constant dense<12> : tensor<2xui64> -// CPU-NEXT: %cst_6 = stablehlo.constant dense<5.000000e-02> : tensor -// CPU-NEXT: %c_7 = stablehlo.constant dense<1> : tensor -// CPU-NEXT: %c_8 = stablehlo.constant dense<0> : tensor -// CPU-NEXT: %cst_9 = stablehlo.constant dense<5.000000e-01> : tensor -// CPU-NEXT: %cst_10 = stablehlo.constant dense<1.000000e+00> : tensor -// CPU-NEXT: %cst_11 = stablehlo.constant dense<0.000000e+00> : tensor +// CPU-NEXT: %c_1 = stablehlo.constant dense<12> : tensor +// CPU-NEXT: %cst_2 = stablehlo.constant dense<1.4142135623730951> : tensor<2xf64> +// CPU-NEXT: %cst_3 = stablehlo.constant dense<2.000000e+00> : tensor<2xf64> +// CPU-NEXT: %cst_4 = stablehlo.constant dense<1.000000e+00> : tensor<2xf64> +// CPU-NEXT: %c_5 = stablehlo.constant dense<4607182418800017408> : tensor<2xui64> +// CPU-NEXT: %c_6 = stablehlo.constant dense<12> : tensor<2xui64> +// CPU-NEXT: %cst_7 = stablehlo.constant dense<5.000000e-02> : tensor +// CPU-NEXT: %c_8 = stablehlo.constant dense<1> : tensor +// CPU-NEXT: %c_9 = stablehlo.constant dense<0> : tensor +// CPU-NEXT: %cst_10 = stablehlo.constant dense<5.000000e-01> : tensor +// CPU-NEXT: %cst_11 = stablehlo.constant dense<1.000000e+00> : tensor // CPU-NEXT: %c_12 = stablehlo.constant dense<10> : tensor // CPU-NEXT: %cst_13 = stablehlo.constant dense<1.000000e-01> : tensor // CPU-NEXT: %cst_14 = stablehlo.constant dense<{{\[}}[1.000000e+00, 0.000000e+00], [0.000000e+00, 1.000000e+00]{{\]}}> : tensor<2x2xf64> @@ -112,77 +116,67 @@ module { // CPU-NEXT: %1 = enzyme.getFlattenedSamplesFromTrace %0 {selection = {{\[}}[#enzyme.symbol<1>], [#enzyme.symbol<2>]{{\]}}} : tensor<2xf64> // CPU-NEXT: %2 = enzyme.getWeightFromTrace %0 : tensor // CPU-NEXT: %3 = stablehlo.negate %2 : tensor +// CPU-NEXT: %4 = stablehlo.cholesky %cst_14, lower = true : tensor<2x2xf64> +// CPU-NEXT: %5 = "stablehlo.triangular_solve"(%4, %cst_14) <{left_side = true, lower = true, transpose_a = #stablehlo, unit_diagonal = false}> : (tensor<2x2xf64>, tensor<2x2xf64>) -> tensor<2x2xf64> +// CPU-NEXT: %6 = "stablehlo.triangular_solve"(%4, %5) <{left_side = true, lower = true, transpose_a = #stablehlo, unit_diagonal = false}> : (tensor<2x2xf64>, tensor<2x2xf64>) -> tensor<2x2xf64> // CPU-NEXT: %output_state, %output = stablehlo.rng_bit_generator %arg0, algorithm = DEFAULT : (tensor<2xui64>) -> (tensor<2xui64>, tensor<2xui64>) -// CPU-NEXT: %4 = stablehlo.shift_right_logical %output, %c_5 : tensor<2xui64> -// CPU-NEXT: %5 = stablehlo.or %4, %c_4 : tensor<2xui64> -// CPU-NEXT: %6 = stablehlo.bitcast_convert %5 : (tensor<2xui64>) -> tensor<2xf64> -// CPU-NEXT: %7 = stablehlo.subtract %6, %cst_3 : tensor<2xf64> -// CPU-NEXT: %8 = stablehlo.multiply %7, %cst_2 : tensor<2xf64> -// CPU-NEXT: %9 = stablehlo.subtract %8, %cst_3 : tensor<2xf64> -// CPU-NEXT: %10 = chlo.erf_inv %9 : tensor<2xf64> -> tensor<2xf64> -// CPU-NEXT: %11 = stablehlo.multiply %10, %cst_1 : tensor<2xf64> -// CPU-NEXT: %12 = stablehlo.cholesky %cst_14, lower = true : tensor<2x2xf64> -// CPU-NEXT: %13 = stablehlo.dot_general %12, %11, contracting_dims = [1] x [0] : (tensor<2x2xf64>, tensor<2xf64>) -> tensor<2xf64> -// CPU-NEXT: %14 = stablehlo.broadcast_in_dim %cst_11, dims = [] : (tensor) -> tensor<2xf64> -// CPU-NEXT: %15 = stablehlo.add %14, %13 : tensor<2xf64> -// CPU-NEXT: %16 = stablehlo.reshape %15 : (tensor<2xf64>) -> tensor<2x1xf64> -// CPU-NEXT: %17 = "stablehlo.triangular_solve"(%12, %16) <{left_side = true, lower = true, transpose_a = #stablehlo, unit_diagonal = false}> : (tensor<2x2xf64>, tensor<2x1xf64>) -> tensor<2x1xf64> -// CPU-NEXT: %18 = "stablehlo.triangular_solve"(%12, %17) <{left_side = true, lower = true, transpose_a = #stablehlo, unit_diagonal = false}> : (tensor<2x2xf64>, tensor<2x1xf64>) -> tensor<2x1xf64> -// CPU-NEXT: %19 = stablehlo.reshape %18 : (tensor<2x1xf64>) -> tensor<2xf64> -// CPU-NEXT: %20 = stablehlo.dot_general %15, %19, contracting_dims = [0] x [0] : (tensor<2xf64>, tensor<2xf64>) -> tensor -// CPU-NEXT: %21 = stablehlo.multiply %20, %cst_9 : tensor -// CPU-NEXT: %22 = stablehlo.add %3, %21 : tensor -// CPU-NEXT: %23 = stablehlo.reshape %cst : (tensor) -> tensor<1xf64> -// CPU-NEXT: %24 = stablehlo.pad %23, %cst, low = [1], high = [0], interior = [0] : (tensor<1xf64>, tensor) -> tensor<2xf64> -// CPU-NEXT: %25 = stablehlo.pad %23, %cst, low = [0], high = [1], interior = [0] : (tensor<1xf64>, tensor) -> tensor<2xf64> -// CPU-NEXT: %26 = arith.addf %24, %25 : tensor<2xf64> -// CPU-NEXT: %27 = stablehlo.broadcast_in_dim %cst_13, dims = [] : (tensor) -> tensor<2xf64> -// CPU-NEXT: %28 = stablehlo.broadcast_in_dim %cst_6, dims = [] : (tensor) -> tensor<2xf64> -// CPU-NEXT: %29:5 = stablehlo.while(%iterArg = %c_8, %iterArg_17 = %1, %iterArg_18 = %15, %iterArg_19 = %26, %iterArg_20 = %output_state) : tensor, tensor<2xf64>, tensor<2xf64>, tensor<2xf64>, tensor<2xui64> +// CPU-NEXT: %7 = stablehlo.shift_right_logical %output, %c_6 : tensor<2xui64> +// CPU-NEXT: %8 = stablehlo.or %7, %c_5 : tensor<2xui64> +// CPU-NEXT: %9 = stablehlo.bitcast_convert %8 : (tensor<2xui64>) -> tensor<2xf64> +// CPU-NEXT: %10 = stablehlo.subtract %9, %cst_4 : tensor<2xf64> +// CPU-NEXT: %11 = stablehlo.multiply %10, %cst_3 : tensor<2xf64> +// CPU-NEXT: %12 = stablehlo.subtract %11, %cst_4 : tensor<2xf64> +// CPU-NEXT: %13 = chlo.erf_inv %12 : tensor<2xf64> -> tensor<2xf64> +// CPU-NEXT: %14 = stablehlo.multiply %13, %cst_2 : tensor<2xf64> +// CPU-NEXT: %15 = stablehlo.dot_general %6, %14, contracting_dims = [1] x [0] : (tensor<2x2xf64>, tensor<2xf64>) -> tensor<2xf64> +// CPU-NEXT: %16 = stablehlo.dot_general %cst_14, %15, contracting_dims = [1] x [0] : (tensor<2x2xf64>, tensor<2xf64>) -> tensor<2xf64> +// CPU-NEXT: %17 = stablehlo.dot_general %15, %16, contracting_dims = [0] x [0] : (tensor<2xf64>, tensor<2xf64>) -> tensor +// CPU-NEXT: %18 = stablehlo.multiply %17, %cst_10 : tensor +// CPU-NEXT: %19 = stablehlo.add %3, %18 : tensor +// CPU-NEXT: %20 = stablehlo.reshape %cst_0 : (tensor) -> tensor<1xf64> +// CPU-NEXT: %21 = stablehlo.pad %20, %cst_0, low = [1], high = [0], interior = [0] : (tensor<1xf64>, tensor) -> tensor<2xf64> +// CPU-NEXT: %22 = stablehlo.pad %20, %cst_0, low = [0], high = [1], interior = [0] : (tensor<1xf64>, tensor) -> tensor<2xf64> +// CPU-NEXT: %23 = arith.addf %21, %22 : tensor<2xf64> +// CPU-NEXT: %24 = stablehlo.broadcast_in_dim %cst_13, dims = [] : (tensor) -> tensor<2xf64> +// CPU-NEXT: %25 = stablehlo.broadcast_in_dim %cst_7, dims = [] : (tensor) -> tensor<2xf64> +// CPU-NEXT: %26:5 = stablehlo.while(%iterArg = %c_9, %iterArg_17 = %1, %iterArg_18 = %15, %iterArg_19 = %23, %iterArg_20 = %output_state) : tensor, tensor<2xf64>, tensor<2xf64>, tensor<2xf64>, tensor<2xui64> // CPU-NEXT: cond { -// CPU-NEXT: %57 = stablehlo.compare LT, %iterArg, %c_12 : (tensor, tensor) -> tensor -// CPU-NEXT: stablehlo.return %57 : tensor +// CPU-NEXT: %50 = stablehlo.compare LT, %iterArg, %c_12 : (tensor, tensor) -> tensor +// CPU-NEXT: stablehlo.return %50 : tensor // CPU-NEXT: } do { -// CPU-NEXT: %57 = stablehlo.multiply %28, %iterArg_19 : tensor<2xf64> -// CPU-NEXT: %58 = stablehlo.subtract %iterArg_18, %57 : tensor<2xf64> -// CPU-NEXT: %59 = stablehlo.reshape %58 : (tensor<2xf64>) -> tensor<2x1xf64> -// CPU-NEXT: %60 = "stablehlo.triangular_solve"(%12, %59) <{left_side = true, lower = true, transpose_a = #stablehlo, unit_diagonal = false}> : (tensor<2x2xf64>, tensor<2x1xf64>) -> tensor<2x1xf64> -// CPU-NEXT: %61 = "stablehlo.triangular_solve"(%12, %60) <{left_side = true, lower = true, transpose_a = #stablehlo, unit_diagonal = false}> : (tensor<2x2xf64>, tensor<2x1xf64>) -> tensor<2x1xf64> -// CPU-NEXT: %62 = stablehlo.reshape %61 : (tensor<2x1xf64>) -> tensor<2xf64> -// CPU-NEXT: %63 = stablehlo.multiply %27, %62 : tensor<2xf64> -// CPU-NEXT: %64 = stablehlo.add %iterArg_17, %63 : tensor<2xf64> -// CPU-NEXT: %65 = stablehlo.multiply %28, %26 : tensor<2xf64> -// CPU-NEXT: %66 = stablehlo.subtract %58, %65 : tensor<2xf64> -// CPU-NEXT: %67 = stablehlo.add %iterArg, %c_7 : tensor -// CPU-NEXT: stablehlo.return %67, %64, %66, %26, %iterArg_20 : tensor, tensor<2xf64>, tensor<2xf64>, tensor<2xf64>, tensor<2xui64> +// CPU-NEXT: %50 = stablehlo.multiply %25, %iterArg_19 : tensor<2xf64> +// CPU-NEXT: %51 = stablehlo.subtract %iterArg_18, %50 : tensor<2xf64> +// CPU-NEXT: %52 = stablehlo.dot_general %cst_14, %51, contracting_dims = [1] x [0] : (tensor<2x2xf64>, tensor<2xf64>) -> tensor<2xf64> +// CPU-NEXT: %53 = stablehlo.multiply %24, %52 : tensor<2xf64> +// CPU-NEXT: %54 = stablehlo.add %iterArg_17, %53 : tensor<2xf64> +// CPU-NEXT: %55 = stablehlo.multiply %25, %23 : tensor<2xf64> +// CPU-NEXT: %56 = stablehlo.subtract %51, %55 : tensor<2xf64> +// CPU-NEXT: %57 = stablehlo.add %iterArg, %c_8 : tensor +// CPU-NEXT: stablehlo.return %57, %54, %56, %23, %iterArg_20 : tensor, tensor<2xf64>, tensor<2xf64>, tensor<2xf64>, tensor<2xui64> // CPU-NEXT: } -// CPU-NEXT: %30 = enzyme.initTrace : !enzyme.Trace -// CPU-NEXT: %31 = stablehlo.slice %29#1 [0:1] : (tensor<2xf64>) -> tensor<1xf64> +// CPU-NEXT: %27 = enzyme.initTrace : !enzyme.Trace +// CPU-NEXT: %28 = stablehlo.slice %26#1 [0:1] : (tensor<2xf64>) -> tensor<1xf64> +// CPU-NEXT: %29 = stablehlo.reshape %28 : (tensor<1xf64>) -> tensor +// CPU-NEXT: %30 = enzyme.addSampleToTrace(%29 : tensor) into %27 {symbol = #enzyme.symbol<1>} +// CPU-NEXT: %31 = stablehlo.slice %26#1 [1:2] : (tensor<2xf64>) -> tensor<1xf64> // CPU-NEXT: %32 = stablehlo.reshape %31 : (tensor<1xf64>) -> tensor -// CPU-NEXT: %33 = enzyme.addSampleToTrace(%32 : tensor) into %30 {symbol = #enzyme.symbol<1>} -// CPU-NEXT: %34 = stablehlo.slice %29#1 [1:2] : (tensor<2xf64>) -> tensor<1xf64> -// CPU-NEXT: %35 = stablehlo.reshape %34 : (tensor<1xf64>) -> tensor -// CPU-NEXT: %36 = stablehlo.add %cst_11, %cst_11 : tensor -// CPU-NEXT: %37 = enzyme.addSampleToTrace(%35 : tensor) into %33 {symbol = #enzyme.symbol<2>} -// CPU-NEXT: %38 = enzyme.addWeightToTrace(%36 : tensor) into %37 -// CPU-NEXT: %39 = enzyme.addRetvalToTrace(%35 : tensor) into %38 -// CPU-NEXT: %40 = stablehlo.negate %36 : tensor -// CPU-NEXT: %41 = stablehlo.reshape %29#2 : (tensor<2xf64>) -> tensor<2x1xf64> -// CPU-NEXT: %42 = "stablehlo.triangular_solve"(%12, %41) <{left_side = true, lower = true, transpose_a = #stablehlo, unit_diagonal = false}> : (tensor<2x2xf64>, tensor<2x1xf64>) -> tensor<2x1xf64> -// CPU-NEXT: %43 = "stablehlo.triangular_solve"(%12, %42) <{left_side = true, lower = true, transpose_a = #stablehlo, unit_diagonal = false}> : (tensor<2x2xf64>, tensor<2x1xf64>) -> tensor<2x1xf64> -// CPU-NEXT: %44 = stablehlo.reshape %43 : (tensor<2x1xf64>) -> tensor<2xf64> -// CPU-NEXT: %45 = stablehlo.dot_general %29#2, %44, contracting_dims = [0] x [0] : (tensor<2xf64>, tensor<2xf64>) -> tensor -// CPU-NEXT: %46 = stablehlo.multiply %45, %cst_9 : tensor -// CPU-NEXT: %47 = stablehlo.add %40, %46 : tensor -// CPU-NEXT: %48 = stablehlo.subtract %22, %47 : tensor -// CPU-NEXT: %49 = stablehlo.exponential %48 : tensor -// CPU-NEXT: %50 = stablehlo.minimum %49, %cst_10 : tensor -// CPU-NEXT: %output_state_15, %output_16 = stablehlo.rng_bit_generator %29#4, algorithm = DEFAULT : (tensor<2xui64>) -> (tensor<2xui64>, tensor) -// CPU-NEXT: %51 = stablehlo.shift_right_logical %output_16, %c_0 : tensor -// CPU-NEXT: %52 = stablehlo.or %51, %c : tensor -// CPU-NEXT: %53 = stablehlo.bitcast_convert %52 : (tensor) -> tensor -// CPU-NEXT: %54 = stablehlo.subtract %53, %cst_10 : tensor -// CPU-NEXT: %55 = stablehlo.compare LT, %54, %50, FLOAT : (tensor, tensor) -> tensor -// CPU-NEXT: %56 = enzyme.selectTrace %55, %39, %0 : tensor -// CPU-NEXT: return %56, %55, %output_state_15 : !enzyme.Trace, tensor, tensor<2xui64> +// CPU-NEXT: %33 = enzyme.addSampleToTrace(%32 : tensor) into %30 {symbol = #enzyme.symbol<2>} +// CPU-NEXT: %34 = enzyme.addWeightToTrace(%cst : tensor) into %33 +// CPU-NEXT: %35 = enzyme.addRetvalToTrace(%32 : tensor) into %34 +// CPU-NEXT: %36 = stablehlo.negate %cst : tensor +// CPU-NEXT: %37 = stablehlo.dot_general %cst_14, %26#2, contracting_dims = [1] x [0] : (tensor<2x2xf64>, tensor<2xf64>) -> tensor<2xf64> +// CPU-NEXT: %38 = stablehlo.dot_general %26#2, %37, contracting_dims = [0] x [0] : (tensor<2xf64>, tensor<2xf64>) -> tensor +// CPU-NEXT: %39 = stablehlo.multiply %38, %cst_10 : tensor +// CPU-NEXT: %40 = stablehlo.add %36, %39 : tensor +// CPU-NEXT: %41 = stablehlo.subtract %19, %40 : tensor +// CPU-NEXT: %42 = stablehlo.exponential %41 : tensor +// CPU-NEXT: %43 = stablehlo.minimum %42, %cst_11 : tensor +// CPU-NEXT: %output_state_15, %output_16 = stablehlo.rng_bit_generator %26#4, algorithm = DEFAULT : (tensor<2xui64>) -> (tensor<2xui64>, tensor) +// CPU-NEXT: %44 = stablehlo.shift_right_logical %output_16, %c_1 : tensor +// CPU-NEXT: %45 = stablehlo.or %44, %c : tensor +// CPU-NEXT: %46 = stablehlo.bitcast_convert %45 : (tensor) -> tensor +// CPU-NEXT: %47 = stablehlo.subtract %46, %cst_11 : tensor +// CPU-NEXT: %48 = stablehlo.compare LT, %47, %43, FLOAT : (tensor, tensor) -> tensor +// CPU-NEXT: %49 = enzyme.selectTrace %48, %35, %0 : tensor +// CPU-NEXT: return %49, %48, %output_state_15 : !enzyme.Trace, tensor, tensor<2xui64> // CPU-NEXT: } diff --git a/test/lit_tests/probprog/log_add_exp.mlir b/test/lit_tests/probprog/log_add_exp.mlir new file mode 100644 index 000000000..4431be423 --- /dev/null +++ b/test/lit_tests/probprog/log_add_exp.mlir @@ -0,0 +1,21 @@ +// RUN: enzymexlamlir-opt %s --pass-pipeline="builtin.module(lower-probprog-to-stablehlo{backend=cpu})" | FileCheck %s --check-prefix=CPU + +module { + // CPU: func.func @test(%arg0: tensor<10xf64>, %arg1: tensor<10xf64>) -> tensor<10xf64> { + // CPU-NEXT: %0 = stablehlo.maximum %arg0, %arg1 : tensor<10xf64> + // CPU-NEXT: %1 = stablehlo.subtract %arg0, %arg1 : tensor<10xf64> + // CPU-NEXT: %2 = stablehlo.compare NE, %1, %1 : (tensor<10xf64>, tensor<10xf64>) -> tensor<10xi1> + // CPU-NEXT: %3 = stablehlo.add %arg0, %arg1 : tensor<10xf64> + // CPU-NEXT: %4 = stablehlo.abs %1 : tensor<10xf64> + // CPU-NEXT: %5 = stablehlo.negate %4 : tensor<10xf64> + // CPU-NEXT: %6 = stablehlo.exponential %5 : tensor<10xf64> + // CPU-NEXT: %7 = stablehlo.log_plus_one %6 : tensor<10xf64> + // CPU-NEXT: %8 = stablehlo.add %0, %7 : tensor<10xf64> + // CPU-NEXT: %9 = stablehlo.select %2, %3, %8 : tensor<10xi1>, tensor<10xf64> + // CPU-NEXT: return %9 : tensor<10xf64> + // CPU-NEXT: } + func.func @test(%lhs: tensor<10xf64>, %rhs: tensor<10xf64>) -> tensor<10xf64> { + %result = enzyme.log_add_exp %lhs, %rhs : (tensor<10xf64>, tensor<10xf64>) -> tensor<10xf64> + return %result : tensor<10xf64> + } +} diff --git a/test/lit_tests/probprog/loop.mlir b/test/lit_tests/probprog/loop.mlir index 2f7aaabf9..ca1ae800b 100644 --- a/test/lit_tests/probprog/loop.mlir +++ b/test/lit_tests/probprog/loop.mlir @@ -23,7 +23,7 @@ module { %c1 = stablehlo.constant dense<1> : tensor %init = stablehlo.constant dense<0.0> : tensor - %result = enzyme.loop (%c0 : tensor) to (%n : tensor) step (%c1 : tensor) + %result = enzyme.for_loop (%c0 : tensor) to (%n : tensor) step (%c1 : tensor) iter_args(%init : tensor) -> tensor { ^bb0(%iv: tensor, %sum_iter: tensor): @@ -64,11 +64,11 @@ module { %c1 = stablehlo.constant dense<1> : tensor %init = stablehlo.constant dense<0.0> : tensor - %result = enzyme.loop (%c0 : tensor) to (%m : tensor) step (%c1 : tensor) + %result = enzyme.for_loop (%c0 : tensor) to (%m : tensor) step (%c1 : tensor) iter_args(%init : tensor) -> tensor { ^bb0(%i: tensor, %outer_sum_iter: tensor): - %inner_result = enzyme.loop (%c0 : tensor) to (%n : tensor) step (%c1 : tensor) + %inner_result = enzyme.for_loop (%c0 : tensor) to (%n : tensor) step (%c1 : tensor) iter_args(%outer_sum_iter : tensor) -> tensor { ^bb1(%j: tensor, %inner_sum_iter: tensor): @@ -108,7 +108,7 @@ module { %init_sum = stablehlo.constant dense<0.0> : tensor %init_prod = stablehlo.constant dense<1.0> : tensor - %sum, %prod = enzyme.loop (%c0 : tensor) to (%n : tensor) step (%c1 : tensor) + %sum, %prod = enzyme.for_loop (%c0 : tensor) to (%n : tensor) step (%c1 : tensor) iter_args(%init_sum, %init_prod : tensor, tensor) -> tensor, tensor { ^bb0(%iv: tensor, %s_iter: tensor, %p_iter: tensor): @@ -173,7 +173,7 @@ module { %init_sum = stablehlo.constant dense<0.0> : tensor %init_trace = enzyme.initTrace : !enzyme.Trace - %sum, %trace = enzyme.loop (%c0 : tensor) to (%n : tensor) step (%c1 : tensor) + %sum, %trace = enzyme.for_loop (%c0 : tensor) to (%n : tensor) step (%c1 : tensor) iter_args(%init_sum, %init_trace : tensor, !enzyme.Trace) -> tensor, !enzyme.Trace { ^bb0(%iv: tensor, %s_iter: tensor, %t_iter: !enzyme.Trace):