Skip to content

Commit f38cddf

Browse files
Update EnzymeAD/Enzyme to commit 032828cbfef50bfba41443baacc39989c203534b (#1577)
* Update EnzymeAD/Enzyme to commit 032828cbfef50bfba41443baacc39989c203534b Diff: EnzymeAD/Enzyme@0ce301a...032828c * fix fix --------- Co-authored-by: enzymead-bot[bot] <238314553+enzymead-bot[bot]@users.noreply.github.com> Co-authored-by: sbrantq <[email protected]>
1 parent f937a46 commit f38cddf

File tree

2 files changed

+43
-42
lines changed

2 files changed

+43
-42
lines changed

test/lit_tests/probprog/hmc.mlir

Lines changed: 42 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -140,48 +140,49 @@ module {
140140
// CPU-NEXT: %28 = stablehlo.broadcast_in_dim %cst_6, dims = [] : (tensor<f64>) -> tensor<2xf64>
141141
// CPU-NEXT: %29:5 = stablehlo.while(%iterArg = %c_8, %iterArg_17 = %1, %iterArg_18 = %15, %iterArg_19 = %26, %iterArg_20 = %output_state) : tensor<i64>, tensor<2xf64>, tensor<2xf64>, tensor<2xf64>, tensor<2xui64>
142142
// CPU-NEXT: cond {
143-
// CPU-NEXT: %56 = stablehlo.compare LT, %iterArg, %c_12 : (tensor<i64>, tensor<i64>) -> tensor<i1>
144-
// CPU-NEXT: stablehlo.return %56 : tensor<i1>
143+
// CPU-NEXT: %57 = stablehlo.compare LT, %iterArg, %c_12 : (tensor<i64>, tensor<i64>) -> tensor<i1>
144+
// CPU-NEXT: stablehlo.return %57 : tensor<i1>
145145
// CPU-NEXT: } do {
146-
// CPU-NEXT: %56 = stablehlo.multiply %28, %iterArg_19 : tensor<2xf64>
147-
// CPU-NEXT: %57 = stablehlo.subtract %iterArg_18, %56 : tensor<2xf64>
148-
// CPU-NEXT: %58 = stablehlo.reshape %57 : (tensor<2xf64>) -> tensor<2x1xf64>
149-
// CPU-NEXT: %59 = "stablehlo.triangular_solve"(%12, %58) <{left_side = true, lower = true, transpose_a = #stablehlo<transpose NO_TRANSPOSE>, unit_diagonal = false}> : (tensor<2x2xf64>, tensor<2x1xf64>) -> tensor<2x1xf64>
150-
// CPU-NEXT: %60 = "stablehlo.triangular_solve"(%12, %59) <{left_side = true, lower = true, transpose_a = #stablehlo<transpose TRANSPOSE>, unit_diagonal = false}> : (tensor<2x2xf64>, tensor<2x1xf64>) -> tensor<2x1xf64>
151-
// CPU-NEXT: %61 = stablehlo.reshape %60 : (tensor<2x1xf64>) -> tensor<2xf64>
152-
// CPU-NEXT: %62 = stablehlo.multiply %27, %61 : tensor<2xf64>
153-
// CPU-NEXT: %63 = stablehlo.add %iterArg_17, %62 : tensor<2xf64>
154-
// CPU-NEXT: %64 = stablehlo.multiply %28, %26 : tensor<2xf64>
155-
// CPU-NEXT: %65 = stablehlo.subtract %57, %64 : tensor<2xf64>
156-
// CPU-NEXT: %66 = stablehlo.add %iterArg, %c_7 : tensor<i64>
157-
// CPU-NEXT: stablehlo.return %66, %63, %65, %26, %iterArg_20 : tensor<i64>, tensor<2xf64>, tensor<2xf64>, tensor<2xf64>, tensor<2xui64>
146+
// CPU-NEXT: %57 = stablehlo.multiply %28, %iterArg_19 : tensor<2xf64>
147+
// CPU-NEXT: %58 = stablehlo.subtract %iterArg_18, %57 : tensor<2xf64>
148+
// CPU-NEXT: %59 = stablehlo.reshape %58 : (tensor<2xf64>) -> tensor<2x1xf64>
149+
// CPU-NEXT: %60 = "stablehlo.triangular_solve"(%12, %59) <{left_side = true, lower = true, transpose_a = #stablehlo<transpose NO_TRANSPOSE>, unit_diagonal = false}> : (tensor<2x2xf64>, tensor<2x1xf64>) -> tensor<2x1xf64>
150+
// CPU-NEXT: %61 = "stablehlo.triangular_solve"(%12, %60) <{left_side = true, lower = true, transpose_a = #stablehlo<transpose TRANSPOSE>, unit_diagonal = false}> : (tensor<2x2xf64>, tensor<2x1xf64>) -> tensor<2x1xf64>
151+
// CPU-NEXT: %62 = stablehlo.reshape %61 : (tensor<2x1xf64>) -> tensor<2xf64>
152+
// CPU-NEXT: %63 = stablehlo.multiply %27, %62 : tensor<2xf64>
153+
// CPU-NEXT: %64 = stablehlo.add %iterArg_17, %63 : tensor<2xf64>
154+
// CPU-NEXT: %65 = stablehlo.multiply %28, %26 : tensor<2xf64>
155+
// CPU-NEXT: %66 = stablehlo.subtract %58, %65 : tensor<2xf64>
156+
// CPU-NEXT: %67 = stablehlo.add %iterArg, %c_7 : tensor<i64>
157+
// CPU-NEXT: stablehlo.return %67, %64, %66, %26, %iterArg_20 : tensor<i64>, tensor<2xf64>, tensor<2xf64>, tensor<2xf64>, tensor<2xui64>
158158
// CPU-NEXT: }
159-
// CPU-NEXT: %30 = stablehlo.slice %29#1 [0:1] : (tensor<2xf64>) -> tensor<1xf64>
160-
// CPU-NEXT: %31 = stablehlo.reshape %30 : (tensor<1xf64>) -> tensor<f64>
161-
// CPU-NEXT: %32 = enzyme.addSampleToTrace(%31 : tensor<f64>) into %0 {symbol = #enzyme.symbol<1>}
162-
// CPU-NEXT: %33 = stablehlo.slice %29#1 [1:2] : (tensor<2xf64>) -> tensor<1xf64>
163-
// CPU-NEXT: %34 = stablehlo.reshape %33 : (tensor<1xf64>) -> tensor<f64>
164-
// CPU-NEXT: %35 = stablehlo.add %cst_11, %cst_11 : tensor<f64>
165-
// CPU-NEXT: %36 = enzyme.addSampleToTrace(%34 : tensor<f64>) into %32 {symbol = #enzyme.symbol<2>}
166-
// CPU-NEXT: %37 = enzyme.addWeightToTrace(%35 : tensor<f64>) into %36
167-
// CPU-NEXT: %38 = enzyme.addRetvalToTrace(%34 : tensor<f64>) into %37
168-
// CPU-NEXT: %39 = stablehlo.negate %35 : tensor<f64>
169-
// CPU-NEXT: %40 = stablehlo.reshape %29#2 : (tensor<2xf64>) -> tensor<2x1xf64>
170-
// CPU-NEXT: %41 = "stablehlo.triangular_solve"(%12, %40) <{left_side = true, lower = true, transpose_a = #stablehlo<transpose NO_TRANSPOSE>, unit_diagonal = false}> : (tensor<2x2xf64>, tensor<2x1xf64>) -> tensor<2x1xf64>
171-
// CPU-NEXT: %42 = "stablehlo.triangular_solve"(%12, %41) <{left_side = true, lower = true, transpose_a = #stablehlo<transpose TRANSPOSE>, unit_diagonal = false}> : (tensor<2x2xf64>, tensor<2x1xf64>) -> tensor<2x1xf64>
172-
// CPU-NEXT: %43 = stablehlo.reshape %42 : (tensor<2x1xf64>) -> tensor<2xf64>
173-
// CPU-NEXT: %44 = stablehlo.dot_general %29#2, %43, contracting_dims = [0] x [0] : (tensor<2xf64>, tensor<2xf64>) -> tensor<f64>
174-
// CPU-NEXT: %45 = stablehlo.multiply %44, %cst_9 : tensor<f64>
175-
// CPU-NEXT: %46 = stablehlo.add %39, %45 : tensor<f64>
176-
// CPU-NEXT: %47 = stablehlo.subtract %22, %46 : tensor<f64>
177-
// CPU-NEXT: %48 = stablehlo.exponential %47 : tensor<f64>
178-
// CPU-NEXT: %49 = stablehlo.minimum %48, %cst_10 : tensor<f64>
159+
// CPU-NEXT: %30 = enzyme.initTrace : !enzyme.Trace
160+
// CPU-NEXT: %31 = stablehlo.slice %29#1 [0:1] : (tensor<2xf64>) -> tensor<1xf64>
161+
// CPU-NEXT: %32 = stablehlo.reshape %31 : (tensor<1xf64>) -> tensor<f64>
162+
// CPU-NEXT: %33 = enzyme.addSampleToTrace(%32 : tensor<f64>) into %30 {symbol = #enzyme.symbol<1>}
163+
// CPU-NEXT: %34 = stablehlo.slice %29#1 [1:2] : (tensor<2xf64>) -> tensor<1xf64>
164+
// CPU-NEXT: %35 = stablehlo.reshape %34 : (tensor<1xf64>) -> tensor<f64>
165+
// CPU-NEXT: %36 = stablehlo.add %cst_11, %cst_11 : tensor<f64>
166+
// CPU-NEXT: %37 = enzyme.addSampleToTrace(%35 : tensor<f64>) into %33 {symbol = #enzyme.symbol<2>}
167+
// CPU-NEXT: %38 = enzyme.addWeightToTrace(%36 : tensor<f64>) into %37
168+
// CPU-NEXT: %39 = enzyme.addRetvalToTrace(%35 : tensor<f64>) into %38
169+
// CPU-NEXT: %40 = stablehlo.negate %36 : tensor<f64>
170+
// CPU-NEXT: %41 = stablehlo.reshape %29#2 : (tensor<2xf64>) -> tensor<2x1xf64>
171+
// CPU-NEXT: %42 = "stablehlo.triangular_solve"(%12, %41) <{left_side = true, lower = true, transpose_a = #stablehlo<transpose NO_TRANSPOSE>, unit_diagonal = false}> : (tensor<2x2xf64>, tensor<2x1xf64>) -> tensor<2x1xf64>
172+
// CPU-NEXT: %43 = "stablehlo.triangular_solve"(%12, %42) <{left_side = true, lower = true, transpose_a = #stablehlo<transpose TRANSPOSE>, unit_diagonal = false}> : (tensor<2x2xf64>, tensor<2x1xf64>) -> tensor<2x1xf64>
173+
// CPU-NEXT: %44 = stablehlo.reshape %43 : (tensor<2x1xf64>) -> tensor<2xf64>
174+
// CPU-NEXT: %45 = stablehlo.dot_general %29#2, %44, contracting_dims = [0] x [0] : (tensor<2xf64>, tensor<2xf64>) -> tensor<f64>
175+
// CPU-NEXT: %46 = stablehlo.multiply %45, %cst_9 : tensor<f64>
176+
// CPU-NEXT: %47 = stablehlo.add %40, %46 : tensor<f64>
177+
// CPU-NEXT: %48 = stablehlo.subtract %22, %47 : tensor<f64>
178+
// CPU-NEXT: %49 = stablehlo.exponential %48 : tensor<f64>
179+
// CPU-NEXT: %50 = stablehlo.minimum %49, %cst_10 : tensor<f64>
179180
// CPU-NEXT: %output_state_15, %output_16 = stablehlo.rng_bit_generator %29#4, algorithm = DEFAULT : (tensor<2xui64>) -> (tensor<2xui64>, tensor<ui64>)
180-
// CPU-NEXT: %50 = stablehlo.shift_right_logical %output_16, %c_0 : tensor<ui64>
181-
// CPU-NEXT: %51 = stablehlo.or %50, %c : tensor<ui64>
182-
// CPU-NEXT: %52 = stablehlo.bitcast_convert %51 : (tensor<ui64>) -> tensor<f64>
183-
// CPU-NEXT: %53 = stablehlo.subtract %52, %cst_10 : tensor<f64>
184-
// CPU-NEXT: %54 = stablehlo.compare LT, %53, %49, FLOAT : (tensor<f64>, tensor<f64>) -> tensor<i1>
185-
// CPU-NEXT: %55 = enzyme.selectTrace %54, %38, %0 : tensor<i1>
186-
// CPU-NEXT: return %55, %54, %output_state_15 : !enzyme.Trace, tensor<i1>, tensor<2xui64>
181+
// CPU-NEXT: %51 = stablehlo.shift_right_logical %output_16, %c_0 : tensor<ui64>
182+
// CPU-NEXT: %52 = stablehlo.or %51, %c : tensor<ui64>
183+
// CPU-NEXT: %53 = stablehlo.bitcast_convert %52 : (tensor<ui64>) -> tensor<f64>
184+
// CPU-NEXT: %54 = stablehlo.subtract %53, %cst_10 : tensor<f64>
185+
// CPU-NEXT: %55 = stablehlo.compare LT, %54, %50, FLOAT : (tensor<f64>, tensor<f64>) -> tensor<i1>
186+
// CPU-NEXT: %56 = enzyme.selectTrace %55, %39, %0 : tensor<i1>
187+
// CPU-NEXT: return %56, %55, %output_state_15 : !enzyme.Trace, tensor<i1>, tensor<2xui64>
187188
// CPU-NEXT: }

workspace.bzl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
JAX_COMMIT = "e9609ce42f272e0a4e908b9e16ea81239e76385c"
22
JAX_SHA256 = ""
33

4-
ENZYME_COMMIT = "0ce301aedef3ca040c8703cb1b7d340ed4a58271"
4+
ENZYME_COMMIT = "032828cbfef50bfba41443baacc39989c203534b"
55
ENZYME_SHA256 = ""
66

77
ML_TOOLCHAIN_COMMIT = "78ef5eda03c54a912c000f1f872242d4ca6063a4"

0 commit comments

Comments
 (0)