Skip to content

Commit c3cafd5

Browse files
committed
rename loop -> for_loop
1 parent 73ee4b9 commit c3cafd5

File tree

3 files changed

+12
-12
lines changed

3 files changed

+12
-12
lines changed

src/enzyme_ad/jax/Passes/LowerEnzymeProbProg.cpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2209,17 +2209,17 @@ struct GetFlattenedSamplesFromTraceOpConversion
22092209
}
22102210
};
22112211

2212-
struct LoopOpConversion : public OpConversionPattern<enzyme::LoopOp> {
2212+
struct ForLoopOpConversion : public OpConversionPattern<enzyme::ForLoopOp> {
22132213
using OpConversionPattern::OpConversionPattern;
22142214

22152215
std::string backend;
2216-
LoopOpConversion(std::string backend, TypeConverter &typeConverter,
2217-
MLIRContext *context, PatternBenefit benefit = 1)
2216+
ForLoopOpConversion(std::string backend, TypeConverter &typeConverter,
2217+
MLIRContext *context, PatternBenefit benefit = 1)
22182218
: OpConversionPattern(typeConverter, context, benefit), backend(backend) {
22192219
}
22202220

22212221
LogicalResult
2222-
matchAndRewrite(enzyme::LoopOp op, OpAdaptor adaptor,
2222+
matchAndRewrite(enzyme::ForLoopOp op, OpAdaptor adaptor,
22232223
ConversionPatternRewriter &rewriter) const override {
22242224
SmallVector<Value> initVals = {adaptor.getLowerBound()};
22252225
initVals.append(adaptor.getInitArgs().begin(), adaptor.getInitArgs().end());
@@ -2365,14 +2365,14 @@ struct LowerProbProgToStableHLOPass
23652365
target.addIllegalOp<enzyme::CholeskySolveOp>();
23662366
target.addIllegalOp<enzyme::DotOp>();
23672367
target.addIllegalOp<enzyme::UnflattenSliceOp>();
2368-
target.addIllegalOp<enzyme::LoopOp>();
2368+
target.addIllegalOp<enzyme::ForLoopOp>();
23692369

23702370
target.addLegalOp<UnrealizedConversionCastOp>();
23712371

23722372
RewritePatternSet patterns(context);
23732373

23742374
patterns.add<RandomOpConversion, CholeskySolveOpConversion, DotOpConversion,
2375-
UnflattenSliceOpConversion, LoopOpConversion>(
2375+
UnflattenSliceOpConversion, ForLoopOpConversion>(
23762376
backend, typeConverter, context);
23772377

23782378
if (failed(applyPartialConversion(getOperation(), target,

test/lit_tests/probprog/hmc.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ module {
4242
} attributes {activity = [#enzyme<activity enzyme_active>], ret_activity = [#enzyme<activity enzyme_activenoneed>, #enzyme<activity enzyme_const>]} : (tensor<2xf64>, tensor<f64>) -> (tensor<2xui64>, tensor<2xf64>)
4343
%9 = "enzyme.broadcast"(%cst_6) <{shape = array<i64: 2>}> : (tensor<f64>) -> tensor<2xf64>
4444
%10 = "enzyme.broadcast"(%cst) <{shape = array<i64: 2>}> : (tensor<f64>) -> tensor<2xf64>
45-
%11:4 = enzyme.loop(%cst_1 : tensor<i64>) to(%cst_5 : tensor<i64>) step(%cst_0 : tensor<i64>) iter_args(%1, %result, %8#1, %8#0 : tensor<2xf64>, tensor<2xf64>, tensor<2xf64>, tensor<2xui64>) -> tensor<2xf64>, tensor<2xf64>, tensor<2xf64>, tensor<2xui64> {
45+
%11:4 = enzyme.for_loop(%cst_1 : tensor<i64>) to(%cst_5 : tensor<i64>) step(%cst_0 : tensor<i64>) iter_args(%1, %result, %8#1, %8#0 : tensor<2xf64>, tensor<2xf64>, tensor<2xf64>, tensor<2xui64>) -> tensor<2xf64>, tensor<2xf64>, tensor<2xf64>, tensor<2xui64> {
4646
^bb0(%arg3: tensor<i64>, %arg4: tensor<2xf64>, %arg5: tensor<2xf64>, %arg6: tensor<2xf64>, %arg7: tensor<2xui64>):
4747
%23 = arith.mulf %10, %arg6 : tensor<2xf64>
4848
%24 = arith.subf %arg5, %23 : tensor<2xf64>

test/lit_tests/probprog/loop.mlir

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ module {
2323
%c1 = stablehlo.constant dense<1> : tensor<i64>
2424
%init = stablehlo.constant dense<0.0> : tensor<f64>
2525

26-
%result = enzyme.loop (%c0 : tensor<i64>) to (%n : tensor<i64>) step (%c1 : tensor<i64>)
26+
%result = enzyme.for_loop (%c0 : tensor<i64>) to (%n : tensor<i64>) step (%c1 : tensor<i64>)
2727
iter_args(%init : tensor<f64>)
2828
-> tensor<f64> {
2929
^bb0(%iv: tensor<i64>, %sum_iter: tensor<f64>):
@@ -64,11 +64,11 @@ module {
6464
%c1 = stablehlo.constant dense<1> : tensor<i64>
6565
%init = stablehlo.constant dense<0.0> : tensor<f64>
6666

67-
%result = enzyme.loop (%c0 : tensor<i64>) to (%m : tensor<i64>) step (%c1 : tensor<i64>)
67+
%result = enzyme.for_loop (%c0 : tensor<i64>) to (%m : tensor<i64>) step (%c1 : tensor<i64>)
6868
iter_args(%init : tensor<f64>)
6969
-> tensor<f64> {
7070
^bb0(%i: tensor<i64>, %outer_sum_iter: tensor<f64>):
71-
%inner_result = enzyme.loop (%c0 : tensor<i64>) to (%n : tensor<i64>) step (%c1 : tensor<i64>)
71+
%inner_result = enzyme.for_loop (%c0 : tensor<i64>) to (%n : tensor<i64>) step (%c1 : tensor<i64>)
7272
iter_args(%outer_sum_iter : tensor<f64>)
7373
-> tensor<f64> {
7474
^bb1(%j: tensor<i64>, %inner_sum_iter: tensor<f64>):
@@ -108,7 +108,7 @@ module {
108108
%init_sum = stablehlo.constant dense<0.0> : tensor<f64>
109109
%init_prod = stablehlo.constant dense<1.0> : tensor<f64>
110110

111-
%sum, %prod = enzyme.loop (%c0 : tensor<i64>) to (%n : tensor<i64>) step (%c1 : tensor<i64>)
111+
%sum, %prod = enzyme.for_loop (%c0 : tensor<i64>) to (%n : tensor<i64>) step (%c1 : tensor<i64>)
112112
iter_args(%init_sum, %init_prod : tensor<f64>, tensor<f64>)
113113
-> tensor<f64>, tensor<f64> {
114114
^bb0(%iv: tensor<i64>, %s_iter: tensor<f64>, %p_iter: tensor<f64>):
@@ -173,7 +173,7 @@ module {
173173
%init_sum = stablehlo.constant dense<0.0> : tensor<f64>
174174
%init_trace = enzyme.initTrace : !enzyme.Trace
175175

176-
%sum, %trace = enzyme.loop (%c0 : tensor<i64>) to (%n : tensor<i64>) step (%c1 : tensor<i64>)
176+
%sum, %trace = enzyme.for_loop (%c0 : tensor<i64>) to (%n : tensor<i64>) step (%c1 : tensor<i64>)
177177
iter_args(%init_sum, %init_trace : tensor<f64>, !enzyme.Trace)
178178
-> tensor<f64>, !enzyme.Trace {
179179
^bb0(%iv: tensor<i64>, %s_iter: tensor<f64>, %t_iter: !enzyme.Trace):

0 commit comments

Comments
 (0)