@@ -140,48 +140,49 @@ module {
140140// CPU-NEXT: %28 = stablehlo.broadcast_in_dim %cst_6, dims = [] : (tensor<f64>) -> tensor<2xf64>
141141// CPU-NEXT: %29:5 = stablehlo.while(%iterArg = %c_8, %iterArg_17 = %1, %iterArg_18 = %15, %iterArg_19 = %26, %iterArg_20 = %output_state) : tensor<i64>, tensor<2xf64>, tensor<2xf64>, tensor<2xf64>, tensor<2xui64>
142142// CPU-NEXT: cond {
143- // CPU-NEXT: %56 = stablehlo.compare LT, %iterArg, %c_12 : (tensor<i64>, tensor<i64>) -> tensor<i1>
144- // CPU-NEXT: stablehlo.return %56 : tensor<i1>
143+ // CPU-NEXT: %57 = stablehlo.compare LT, %iterArg, %c_12 : (tensor<i64>, tensor<i64>) -> tensor<i1>
144+ // CPU-NEXT: stablehlo.return %57 : tensor<i1>
145145// CPU-NEXT: } do {
146- // CPU-NEXT: %56 = stablehlo.multiply %28, %iterArg_19 : tensor<2xf64>
147- // CPU-NEXT: %57 = stablehlo.subtract %iterArg_18, %56 : tensor<2xf64>
148- // CPU-NEXT: %58 = stablehlo.reshape %57 : (tensor<2xf64>) -> tensor<2x1xf64>
149- // CPU-NEXT: %59 = "stablehlo.triangular_solve"(%12, %58 ) <{left_side = true, lower = true, transpose_a = #stablehlo<transpose NO_TRANSPOSE>, unit_diagonal = false}> : (tensor<2x2xf64>, tensor<2x1xf64>) -> tensor<2x1xf64>
150- // CPU-NEXT: %60 = "stablehlo.triangular_solve"(%12, %59 ) <{left_side = true, lower = true, transpose_a = #stablehlo<transpose TRANSPOSE>, unit_diagonal = false}> : (tensor<2x2xf64>, tensor<2x1xf64>) -> tensor<2x1xf64>
151- // CPU-NEXT: %61 = stablehlo.reshape %60 : (tensor<2x1xf64>) -> tensor<2xf64>
152- // CPU-NEXT: %62 = stablehlo.multiply %27, %61 : tensor<2xf64>
153- // CPU-NEXT: %63 = stablehlo.add %iterArg_17, %62 : tensor<2xf64>
154- // CPU-NEXT: %64 = stablehlo.multiply %28, %26 : tensor<2xf64>
155- // CPU-NEXT: %65 = stablehlo.subtract %57 , %64 : tensor<2xf64>
156- // CPU-NEXT: %66 = stablehlo.add %iterArg, %c_7 : tensor<i64>
157- // CPU-NEXT: stablehlo.return %66 , %63 , %65 , %26, %iterArg_20 : tensor<i64>, tensor<2xf64>, tensor<2xf64>, tensor<2xf64>, tensor<2xui64>
146+ // CPU-NEXT: %57 = stablehlo.multiply %28, %iterArg_19 : tensor<2xf64>
147+ // CPU-NEXT: %58 = stablehlo.subtract %iterArg_18, %57 : tensor<2xf64>
148+ // CPU-NEXT: %59 = stablehlo.reshape %58 : (tensor<2xf64>) -> tensor<2x1xf64>
149+ // CPU-NEXT: %60 = "stablehlo.triangular_solve"(%12, %59 ) <{left_side = true, lower = true, transpose_a = #stablehlo<transpose NO_TRANSPOSE>, unit_diagonal = false}> : (tensor<2x2xf64>, tensor<2x1xf64>) -> tensor<2x1xf64>
150+ // CPU-NEXT: %61 = "stablehlo.triangular_solve"(%12, %60 ) <{left_side = true, lower = true, transpose_a = #stablehlo<transpose TRANSPOSE>, unit_diagonal = false}> : (tensor<2x2xf64>, tensor<2x1xf64>) -> tensor<2x1xf64>
151+ // CPU-NEXT: %62 = stablehlo.reshape %61 : (tensor<2x1xf64>) -> tensor<2xf64>
152+ // CPU-NEXT: %63 = stablehlo.multiply %27, %62 : tensor<2xf64>
153+ // CPU-NEXT: %64 = stablehlo.add %iterArg_17, %63 : tensor<2xf64>
154+ // CPU-NEXT: %65 = stablehlo.multiply %28, %26 : tensor<2xf64>
155+ // CPU-NEXT: %66 = stablehlo.subtract %58 , %65 : tensor<2xf64>
156+ // CPU-NEXT: %67 = stablehlo.add %iterArg, %c_7 : tensor<i64>
157+ // CPU-NEXT: stablehlo.return %67 , %64 , %66 , %26, %iterArg_20 : tensor<i64>, tensor<2xf64>, tensor<2xf64>, tensor<2xf64>, tensor<2xui64>
158158// CPU-NEXT: }
159- // CPU-NEXT: %30 = stablehlo.slice %29#1 [0:1] : (tensor<2xf64>) -> tensor<1xf64>
160- // CPU-NEXT: %31 = stablehlo.reshape %30 : (tensor<1xf64>) -> tensor<f64>
161- // CPU-NEXT: %32 = enzyme.addSampleToTrace(%31 : tensor<f64>) into %0 {symbol = #enzyme.symbol<1>}
162- // CPU-NEXT: %33 = stablehlo.slice %29#1 [1:2] : (tensor<2xf64>) -> tensor<1xf64>
163- // CPU-NEXT: %34 = stablehlo.reshape %33 : (tensor<1xf64>) -> tensor<f64>
164- // CPU-NEXT: %35 = stablehlo.add %cst_11, %cst_11 : tensor<f64>
165- // CPU-NEXT: %36 = enzyme.addSampleToTrace(%34 : tensor<f64>) into %32 {symbol = #enzyme.symbol<2>}
166- // CPU-NEXT: %37 = enzyme.addWeightToTrace(%35 : tensor<f64>) into %36
167- // CPU-NEXT: %38 = enzyme.addRetvalToTrace(%34 : tensor<f64>) into %37
168- // CPU-NEXT: %39 = stablehlo.negate %35 : tensor<f64>
169- // CPU-NEXT: %40 = stablehlo.reshape %29#2 : (tensor<2xf64>) -> tensor<2x1xf64>
170- // CPU-NEXT: %41 = "stablehlo.triangular_solve"(%12, %40) <{left_side = true, lower = true, transpose_a = #stablehlo<transpose NO_TRANSPOSE>, unit_diagonal = false}> : (tensor<2x2xf64>, tensor<2x1xf64>) -> tensor<2x1xf64>
171- // CPU-NEXT: %42 = "stablehlo.triangular_solve"(%12, %41) <{left_side = true, lower = true, transpose_a = #stablehlo<transpose TRANSPOSE>, unit_diagonal = false}> : (tensor<2x2xf64>, tensor<2x1xf64>) -> tensor<2x1xf64>
172- // CPU-NEXT: %43 = stablehlo.reshape %42 : (tensor<2x1xf64>) -> tensor<2xf64>
173- // CPU-NEXT: %44 = stablehlo.dot_general %29#2, %43, contracting_dims = [0] x [0] : (tensor<2xf64>, tensor<2xf64>) -> tensor<f64>
174- // CPU-NEXT: %45 = stablehlo.multiply %44, %cst_9 : tensor<f64>
175- // CPU-NEXT: %46 = stablehlo.add %39, %45 : tensor<f64>
176- // CPU-NEXT: %47 = stablehlo.subtract %22, %46 : tensor<f64>
177- // CPU-NEXT: %48 = stablehlo.exponential %47 : tensor<f64>
178- // CPU-NEXT: %49 = stablehlo.minimum %48, %cst_10 : tensor<f64>
159+ // CPU-NEXT: %30 = enzyme.initTrace : !enzyme.Trace
160+ // CPU-NEXT: %31 = stablehlo.slice %29#1 [0:1] : (tensor<2xf64>) -> tensor<1xf64>
161+ // CPU-NEXT: %32 = stablehlo.reshape %31 : (tensor<1xf64>) -> tensor<f64>
162+ // CPU-NEXT: %33 = enzyme.addSampleToTrace(%32 : tensor<f64>) into %30 {symbol = #enzyme.symbol<1>}
163+ // CPU-NEXT: %34 = stablehlo.slice %29#1 [1:2] : (tensor<2xf64>) -> tensor<1xf64>
164+ // CPU-NEXT: %35 = stablehlo.reshape %34 : (tensor<1xf64>) -> tensor<f64>
165+ // CPU-NEXT: %36 = stablehlo.add %cst_11, %cst_11 : tensor<f64>
166+ // CPU-NEXT: %37 = enzyme.addSampleToTrace(%35 : tensor<f64>) into %33 {symbol = #enzyme.symbol<2>}
167+ // CPU-NEXT: %38 = enzyme.addWeightToTrace(%36 : tensor<f64>) into %37
168+ // CPU-NEXT: %39 = enzyme.addRetvalToTrace(%35 : tensor<f64>) into %38
169+ // CPU-NEXT: %40 = stablehlo.negate %36 : tensor<f64>
170+ // CPU-NEXT: %41 = stablehlo.reshape %29#2 : (tensor<2xf64>) -> tensor<2x1xf64>
171+ // CPU-NEXT: %42 = "stablehlo.triangular_solve"(%12, %41) <{left_side = true, lower = true, transpose_a = #stablehlo<transpose NO_TRANSPOSE>, unit_diagonal = false}> : (tensor<2x2xf64>, tensor<2x1xf64>) -> tensor<2x1xf64>
172+ // CPU-NEXT: %43 = "stablehlo.triangular_solve"(%12, %42) <{left_side = true, lower = true, transpose_a = #stablehlo<transpose TRANSPOSE>, unit_diagonal = false}> : (tensor<2x2xf64>, tensor<2x1xf64>) -> tensor<2x1xf64>
173+ // CPU-NEXT: %44 = stablehlo.reshape %43 : (tensor<2x1xf64>) -> tensor<2xf64>
174+ // CPU-NEXT: %45 = stablehlo.dot_general %29#2, %44, contracting_dims = [0] x [0] : (tensor<2xf64>, tensor<2xf64>) -> tensor<f64>
175+ // CPU-NEXT: %46 = stablehlo.multiply %45, %cst_9 : tensor<f64>
176+ // CPU-NEXT: %47 = stablehlo.add %40, %46 : tensor<f64>
177+ // CPU-NEXT: %48 = stablehlo.subtract %22, %47 : tensor<f64>
178+ // CPU-NEXT: %49 = stablehlo.exponential %48 : tensor<f64>
179+ // CPU-NEXT: %50 = stablehlo.minimum %49, %cst_10 : tensor<f64>
179180// CPU-NEXT: %output_state_15, %output_16 = stablehlo.rng_bit_generator %29#4, algorithm = DEFAULT : (tensor<2xui64>) -> (tensor<2xui64>, tensor<ui64>)
180- // CPU-NEXT: %50 = stablehlo.shift_right_logical %output_16, %c_0 : tensor<ui64>
181- // CPU-NEXT: %51 = stablehlo.or %50 , %c : tensor<ui64>
182- // CPU-NEXT: %52 = stablehlo.bitcast_convert %51 : (tensor<ui64>) -> tensor<f64>
183- // CPU-NEXT: %53 = stablehlo.subtract %52 , %cst_10 : tensor<f64>
184- // CPU-NEXT: %54 = stablehlo.compare LT, %53 , %49 , FLOAT : (tensor<f64>, tensor<f64>) -> tensor<i1>
185- // CPU-NEXT: %55 = enzyme.selectTrace %54 , %38 , %0 : tensor<i1>
186- // CPU-NEXT: return %55 , %54 , %output_state_15 : !enzyme.Trace, tensor<i1>, tensor<2xui64>
181+ // CPU-NEXT: %51 = stablehlo.shift_right_logical %output_16, %c_0 : tensor<ui64>
182+ // CPU-NEXT: %52 = stablehlo.or %51 , %c : tensor<ui64>
183+ // CPU-NEXT: %53 = stablehlo.bitcast_convert %52 : (tensor<ui64>) -> tensor<f64>
184+ // CPU-NEXT: %54 = stablehlo.subtract %53 , %cst_10 : tensor<f64>
185+ // CPU-NEXT: %55 = stablehlo.compare LT, %54 , %50 , FLOAT : (tensor<f64>, tensor<f64>) -> tensor<i1>
186+ // CPU-NEXT: %56 = enzyme.selectTrace %55 , %39 , %0 : tensor<i1>
187+ // CPU-NEXT: return %56 , %55 , %output_state_15 : !enzyme.Trace, tensor<i1>, tensor<2xui64>
187188// CPU-NEXT: }
0 commit comments