Skip to content

Commit 0bb88e0

Browse files
committed
rename loop -> for_loop
1 parent c051c60 commit 0bb88e0

File tree

3 files changed

+12
-12
lines changed

3 files changed

+12
-12
lines changed

src/enzyme_ad/jax/Passes/LowerEnzymeProbProg.cpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2239,17 +2239,17 @@ struct GetFlattenedSamplesFromTraceOpConversion
22392239
}
22402240
};
22412241

2242-
struct LoopOpConversion : public OpConversionPattern<enzyme::LoopOp> {
2242+
struct ForLoopOpConversion : public OpConversionPattern<enzyme::ForLoopOp> {
22432243
using OpConversionPattern::OpConversionPattern;
22442244

22452245
std::string backend;
2246-
LoopOpConversion(std::string backend, TypeConverter &typeConverter,
2247-
MLIRContext *context, PatternBenefit benefit = 1)
2246+
ForLoopOpConversion(std::string backend, TypeConverter &typeConverter,
2247+
MLIRContext *context, PatternBenefit benefit = 1)
22482248
: OpConversionPattern(typeConverter, context, benefit), backend(backend) {
22492249
}
22502250

22512251
LogicalResult
2252-
matchAndRewrite(enzyme::LoopOp op, OpAdaptor adaptor,
2252+
matchAndRewrite(enzyme::ForLoopOp op, OpAdaptor adaptor,
22532253
ConversionPatternRewriter &rewriter) const override {
22542254
SmallVector<Value> initVals = {adaptor.getLowerBound()};
22552255
initVals.append(adaptor.getInitArgs().begin(), adaptor.getInitArgs().end());
@@ -2395,14 +2395,14 @@ struct LowerProbProgToStableHLOPass
23952395
target.addIllegalOp<enzyme::CholeskySolveOp>();
23962396
target.addIllegalOp<enzyme::DotOp>();
23972397
target.addIllegalOp<enzyme::UnflattenSliceOp>();
2398-
target.addIllegalOp<enzyme::LoopOp>();
2398+
target.addIllegalOp<enzyme::ForLoopOp>();
23992399

24002400
target.addLegalOp<UnrealizedConversionCastOp>();
24012401

24022402
RewritePatternSet patterns(context);
24032403

24042404
patterns.add<RandomOpConversion, CholeskySolveOpConversion, DotOpConversion,
2405-
UnflattenSliceOpConversion, LoopOpConversion>(
2405+
UnflattenSliceOpConversion, ForLoopOpConversion>(
24062406
backend, typeConverter, context);
24072407

24082408
if (failed(applyPartialConversion(getOperation(), target,

test/lit_tests/probprog/hmc.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ module {
4242
} attributes {activity = [#enzyme<activity enzyme_active>], ret_activity = [#enzyme<activity enzyme_activenoneed>, #enzyme<activity enzyme_const>]} : (tensor<2xf64>, tensor<f64>) -> (tensor<2xui64>, tensor<2xf64>)
4343
%9 = "enzyme.broadcast"(%cst_6) <{shape = array<i64: 2>}> : (tensor<f64>) -> tensor<2xf64>
4444
%10 = "enzyme.broadcast"(%cst) <{shape = array<i64: 2>}> : (tensor<f64>) -> tensor<2xf64>
45-
%11:4 = enzyme.loop(%cst_1 : tensor<i64>) to(%cst_5 : tensor<i64>) step(%cst_0 : tensor<i64>) iter_args(%1, %result, %8#1, %8#0 : tensor<2xf64>, tensor<2xf64>, tensor<2xf64>, tensor<2xui64>) -> tensor<2xf64>, tensor<2xf64>, tensor<2xf64>, tensor<2xui64> {
45+
%11:4 = enzyme.for_loop(%cst_1 : tensor<i64>) to(%cst_5 : tensor<i64>) step(%cst_0 : tensor<i64>) iter_args(%1, %result, %8#1, %8#0 : tensor<2xf64>, tensor<2xf64>, tensor<2xf64>, tensor<2xui64>) -> tensor<2xf64>, tensor<2xf64>, tensor<2xf64>, tensor<2xui64> {
4646
^bb0(%arg3: tensor<i64>, %arg4: tensor<2xf64>, %arg5: tensor<2xf64>, %arg6: tensor<2xf64>, %arg7: tensor<2xui64>):
4747
%23 = arith.mulf %10, %arg6 : tensor<2xf64>
4848
%24 = arith.subf %arg5, %23 : tensor<2xf64>

test/lit_tests/probprog/loop.mlir

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ module {
2323
%c1 = stablehlo.constant dense<1> : tensor<i64>
2424
%init = stablehlo.constant dense<0.0> : tensor<f64>
2525

26-
%result = enzyme.loop (%c0 : tensor<i64>) to (%n : tensor<i64>) step (%c1 : tensor<i64>)
26+
%result = enzyme.for_loop (%c0 : tensor<i64>) to (%n : tensor<i64>) step (%c1 : tensor<i64>)
2727
iter_args(%init : tensor<f64>)
2828
-> tensor<f64> {
2929
^bb0(%iv: tensor<i64>, %sum_iter: tensor<f64>):
@@ -64,11 +64,11 @@ module {
6464
%c1 = stablehlo.constant dense<1> : tensor<i64>
6565
%init = stablehlo.constant dense<0.0> : tensor<f64>
6666

67-
%result = enzyme.loop (%c0 : tensor<i64>) to (%m : tensor<i64>) step (%c1 : tensor<i64>)
67+
%result = enzyme.for_loop (%c0 : tensor<i64>) to (%m : tensor<i64>) step (%c1 : tensor<i64>)
6868
iter_args(%init : tensor<f64>)
6969
-> tensor<f64> {
7070
^bb0(%i: tensor<i64>, %outer_sum_iter: tensor<f64>):
71-
%inner_result = enzyme.loop (%c0 : tensor<i64>) to (%n : tensor<i64>) step (%c1 : tensor<i64>)
71+
%inner_result = enzyme.for_loop (%c0 : tensor<i64>) to (%n : tensor<i64>) step (%c1 : tensor<i64>)
7272
iter_args(%outer_sum_iter : tensor<f64>)
7373
-> tensor<f64> {
7474
^bb1(%j: tensor<i64>, %inner_sum_iter: tensor<f64>):
@@ -108,7 +108,7 @@ module {
108108
%init_sum = stablehlo.constant dense<0.0> : tensor<f64>
109109
%init_prod = stablehlo.constant dense<1.0> : tensor<f64>
110110

111-
%sum, %prod = enzyme.loop (%c0 : tensor<i64>) to (%n : tensor<i64>) step (%c1 : tensor<i64>)
111+
%sum, %prod = enzyme.for_loop (%c0 : tensor<i64>) to (%n : tensor<i64>) step (%c1 : tensor<i64>)
112112
iter_args(%init_sum, %init_prod : tensor<f64>, tensor<f64>)
113113
-> tensor<f64>, tensor<f64> {
114114
^bb0(%iv: tensor<i64>, %s_iter: tensor<f64>, %p_iter: tensor<f64>):
@@ -173,7 +173,7 @@ module {
173173
%init_sum = stablehlo.constant dense<0.0> : tensor<f64>
174174
%init_trace = enzyme.initTrace : !enzyme.Trace
175175

176-
%sum, %trace = enzyme.loop (%c0 : tensor<i64>) to (%n : tensor<i64>) step (%c1 : tensor<i64>)
176+
%sum, %trace = enzyme.for_loop (%c0 : tensor<i64>) to (%n : tensor<i64>) step (%c1 : tensor<i64>)
177177
iter_args(%init_sum, %init_trace : tensor<f64>, !enzyme.Trace)
178178
-> tensor<f64>, !enzyme.Trace {
179179
^bb0(%iv: tensor<i64>, %s_iter: tensor<f64>, %t_iter: !enzyme.Trace):

0 commit comments

Comments
 (0)