Commit e1ca305

Handle mix of structured and unstructured accesses in loops (#290)
Previously, before the introduction of the PtrDialect fallback, the structured-to-memref pass had to convert loop iter-args of triton pointer type to unranked memref. This conversion ensured that all types coming out of triton-shared were MLIR built-in types, which allowed the CPU backend to correctly lower the IR to LLVM. In reality, however, structured ops do not need to use the loop iter-args at all: ptr-analysis generates load/store ops that use the kernel arguments directly as their source, so the conversion is mostly unnecessary. With the introduction of the fallback through the PtrDialect (triton-to-ptr), we also convert loop iter-args of triton pointer type to the PtrDialect's ptr type. Combined with the unranked-memref conversion above, this leaves us with `unrealized_conversion_cast` ops that convert back and forth between the two types whenever a triton program mixes structured and unstructured accesses in loops.

To solve this issue, we:

- remove the conversion of loop iter-args of triton pointer type to unranked memref, since it is unnecessary as described above
- run remove-dead-values to remove unused loop iter-args; this pass previously could not run in the presence of ops with arbitrary regions, but that has been fixed in llvm/llvm-project#140793

Running remove-dead-values gives two benefits:

- it removes every loop iter-arg that is no longer used after ptr-analysis
- it makes our codegen more efficient in general
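As a rough illustration of why the pointer iter-arg becomes dead (a hand-written sketch, not taken from this commit; the function name, sizes, and constants are invented), this is the shape ptr-analysis produces for a structured access inside a loop: the load addresses the kernel argument directly through a `memref.reinterpret_cast` at a loop-carried index offset, so no pointer-typed iter-arg is needed.

```mlir
// Hypothetical IR after ptr-analysis / structured-to-memref: the loop carries
// only an index offset; the source buffer is the kernel argument %arg0 itself.
func.func @sketch(%arg0: memref<*xf32>, %init: index) -> index {
  %c0 = arith.constant 0 : index
  %c3 = arith.constant 3 : index
  %c12 = arith.constant 12 : index
  %res = scf.for %i = %c0 to %c12 step %c3 iter_args(%off = %init) -> (index) {
    // Structured load: reinterpret the kernel argument at the current offset.
    %view = memref.reinterpret_cast %arg0 to offset: [%off], sizes: [1024], strides: [1]
        : memref<*xf32> to memref<1024xf32, strided<[1], offset: ?>>
    %buf = memref.alloc() : memref<1024xf32>
    memref.copy %view, %buf : memref<1024xf32, strided<[1], offset: ?>> to memref<1024xf32>
    // Advance the offset for the next iteration.
    %next = arith.addi %off, %i : index
    scf.yield %next : index
  }
  return %res : index
}
```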
1 parent 733cfef commit e1ca305

11 files changed: +240 −292 lines

lib/Conversion/StructuredToMemref/StructuredToMemrefPass.cpp

Lines changed: 14 additions & 76 deletions
@@ -8,6 +8,7 @@
 #include "triton/Dialect/Triton/IR/Dialect.h"

 #include "triton-shared/Conversion/StructuredToMemref/StructuredToMemref.h"
+#include "triton-shared/Dialect/TPtr/IR/TPtrDialect.h"
 #include "triton-shared/Dialect/TritonStructured/IR/TritonStructuredDialect.h"
 #include "triton-shared/Dialect/TritonTilingExt/IR/TritonTilingExtDialect.h"

@@ -24,12 +25,7 @@
 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/SCF/Transforms/Patterns.h"
-#include "mlir/Pass/PassManager.h"
 #include "triton/Dialect/Triton/IR/Types.h"
-#include "llvm/ADT/STLExtras.h"
-#include "llvm/Support/Casting.h"
-
-#include <optional>

 #define DEBUG_TYPE "structured-to-memref"

@@ -45,67 +41,6 @@ namespace triton {

 namespace {

-class LoopTypeConverter : public TypeConverter {
-public:
-  LoopTypeConverter(MLIRContext *context) {
-    // The order of type conversion is important: later ones are tried earlier.
-    addConversion([](Type type) { return type; });
-    // addConversion([context](triton::PointerType ptrType) {
-    //   SmallVector<int64_t> strides{1};
-    //   auto layout =
-    //       StridedLayoutAttr::get(context, ShapedType::kDynamic, strides);
-
-    //   auto elemType = ptrType.getPointeeType();
-    //   auto memrefType = MemRefType::get({1}, elemType, layout);
-    //   return memrefType;
-    // });
-
-    // A tensor of pointers can be passed in as scf.for's init-args, in such
-    // cases, we convert the type to a memref with dynamic offsets and
-    // strides.
-    addConversion(
-        [context](RankedTensorType tensorType) -> std::optional<MemRefType> {
-          if (auto ptrType = llvm::dyn_cast<triton::PointerType>(
-                  tensorType.getElementType())) {
-            auto layout = StridedLayoutAttr::get(
-                context, ShapedType::kDynamic,
-                SmallVector<int64_t>(tensorType.getRank(),
-                                     ShapedType::kDynamic));
-            Type elemType = ptrType.getPointeeType();
-            return MemRefType::get(tensorType.getShape(), elemType, layout);
-          }
-
-          return std::nullopt;
-        });
-
-    // Convert the current memref type to a memref type with dynamic offsets and
-    // strides through another reinterpret_cast with the same offsets.
-    // Canonicalization will simplify this sequence by removing the inital
-    // reinterpret_cast.
-    addTargetMaterialization([&](OpBuilder &builder, MemRefType memrefType,
-                                 ValueRange inputs,
-                                 Location loc) -> Value {
-      auto reinterpretCast =
-          inputs[0].getDefiningOp<memref::ReinterpretCastOp>();
-      if (!reinterpretCast) {
-        return builder
-            .create<UnrealizedConversionCastOp>(loc, memrefType, inputs)
-            .getResult(0);
-      }
-      return builder.create<memref::ReinterpretCastOp>(
-          loc, memrefType, inputs[0], reinterpretCast.getMixedOffsets()[0],
-          reinterpretCast.getMixedSizes(), reinterpretCast.getMixedStrides());
-    });
-
-    addSourceMaterialization([&](OpBuilder &builder, Type resultType,
-                                 ValueRange inputs,
-                                 Location loc) -> Value {
-      return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs)
-          .getResult(0);
-    });
-  }
-};
-
 class PtrToUnrankedMemrefConverter : public TypeConverter {
 public:
   PtrToUnrankedMemrefConverter() {
@@ -120,6 +55,13 @@ class PtrToUnrankedMemrefConverter : public TypeConverter {
       return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs)
          .getResult(0);
     });
+
+    addSourceMaterialization([&](OpBuilder &builder, Type resultType,
+                                 ValueRange inputs,
+                                 Location loc) -> Value {
+      return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs)
+          .getResult(0);
+    });
   }
 };

@@ -129,11 +71,12 @@ class StructuredToMemrefPass

 public:
   void getDependentDialects(DialectRegistry &registry) const override {
-    registry.insert<func::FuncDialect, arith::ArithDialect, math::MathDialect,
-                    linalg::LinalgDialect, affine::AffineDialect,
-                    scf::SCFDialect, tensor::TensorDialect,
-                    bufferization::BufferizationDialect, triton::TritonDialect,
-                    ttx::TritonTilingExtDialect, memref::MemRefDialect>();
+    registry
+        .insert<tptr::TPtrDialect, func::FuncDialect, arith::ArithDialect,
+                math::MathDialect, linalg::LinalgDialect, affine::AffineDialect,
+                scf::SCFDialect, tensor::TensorDialect,
+                bufferization::BufferizationDialect, triton::TritonDialect,
+                ttx::TritonTilingExtDialect, memref::MemRefDialect>();
   }

   void runOnOperation() override {
@@ -158,11 +101,6 @@ class StructuredToMemrefPass
     triton::populateStructuredToMemrefConversionPatterns(patterns,
                                                           typeConverter);

-    LoopTypeConverter loopTypeConverter(patterns.getContext());
-
-    mlir::scf::populateSCFStructuralTypeConversionsAndLegality(
-        loopTypeConverter, patterns, target);
-
     if (failed(applyPartialConversion(moduleOp, target, std::move(patterns)))) {
       signalPassFailure();
     }

lib/Conversion/TritonToLinalgExperimental/TritonToLinalgExperimentalPass.cpp

Lines changed: 4 additions & 9 deletions
@@ -60,22 +60,17 @@ class TritonToLinalgExperimentalPass
     pm.addPass(createTritonToUnstructuredPass());
     pm.addPass(createTritonArithToLinalgPass(/*tensorPtrToLinalg=*/true));

-    // TODO: structured-to-memref converts the loop iter-args to memref, while
-    // triton-to-ptr converts the loop iter-args to ptr. These two passes might
-    // end up conflicting with each other in cases where we have a mixed of
-    // structured and unstructured accesses. Fortunately, the structured ops do
-    // not need to use the loop iter-args at all (see code in PtrAnalysis.cpp),
-    // so if we run remove-dead-values after structured-to-memref, the memref
-    // iter-args that are used in structured loads and stores should be removed.
-    // Running this now may be too invasive and cause many IR changes, so
-    // leave as a TODO for now.
     pm.addPass(createStructuredToMemrefPass());
     pm.addPass(createUnstructuredToMemrefPass());
     pm.addPass(createTritonPtrToMemrefPass());
     pm.addPass(createTritonToPtrPass());
     pm.addPass(createReconcileUnrealizedCastsPass());
     pm.addPass(createReconcilePtrCastsPass());

+    // Now that remove-dead-values fully works with linalg ops, clean up the IR
+    // again, particularly unused loop iter-args that were created
+    // during triton-to-structured.
+    pm.addPass(createRemoveDeadValuesPass());
     pm.addPass(createCSEPass());
     pm.addPass(createCanonicalizerPass());
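To illustrate what the newly scheduled remove-dead-values pass is expected to clean up, here is a hand-written before/after sketch (invented names and bounds, not actual output of this pipeline): an iter-arg that only feeds its own yield, together with the corresponding unused loop result, should be dropped.

```mlir
// Before: %dead is carried through the loop but only feeds its own yield,
// and the corresponding loop result %r#1 is never used.
func.func @before(%n: index, %base: index) -> index {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %r:2 = scf.for %i = %c0 to %n step %c1
      iter_args(%off = %base, %dead = %c0) -> (index, index) {
    %next = arith.addi %off, %c1 : index
    scf.yield %next, %dead : index, index
  }
  return %r#0 : index
}

// After remove-dead-values (expected shape): the dead iter-arg, its yield
// operand, and the unused result are gone.
func.func @after(%n: index, %base: index) -> index {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %r = scf.for %i = %c0 to %n step %c1 iter_args(%off = %base) -> (index) {
    %next = arith.addi %off, %c1 : index
    scf.yield %next : index
  }
  return %r : index
}
```

In isolation, the same cleanup should be reproducible with the upstream pass, e.g. `mlir-opt --remove-dead-values input.mlir`.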

test/Conversion/StructuredToMemref/addptr_scalar_for.mlir

Lines changed: 4 additions & 6 deletions
@@ -1,4 +1,4 @@
-// RUN: triton-shared-opt --split-input-file --triton-to-linalg-experimental %s | FileCheck %s
+// RUN: triton-shared-opt --triton-to-linalg-experimental %s | FileCheck %s
 module {
   tt.func @kernel (%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg2: i32, %arg3: i32, %arg4: i32) {
     %0 = tt.get_program_id x : i32
@@ -47,7 +47,7 @@ module {
 // CHECK-DAG:       [[VAR_1_:%.+]] = linalg.fill ins([[CST_0_dot_000000_]] : f32) outs([[VAR_0_]] : tensor<1024xf32>) -> tensor<1024xf32>
 // CHECK-DAG:       [[VAR_2_:%.+]] = arith.muli [[PARAM_8_]], [[PARAM_2_]] : i32
 // CHECK:           [[VAR_3_:%.+]] = arith.index_cast [[VAR_2_]] : i32 to index
-// CHECK-DAG:       [[VAR_4_:%.+]]:3 = scf.for [[VAR_arg11_:%.+]] = [[CST_0_]] to [[CST_12_]] step [[CST_3_]] iter_args([[VAR_arg12_:%.+]] = [[VAR_2_]], [[VAR_arg13_:%.+]] = [[VAR_3_]], [[VAR_arg14_:%.+]] = [[VAR_1_]]) -> (i32, index, tensor<1024xf32>) {
+// CHECK-DAG:       [[VAR_4_:%.+]]:2 = scf.for [[VAR_arg11_:%.+]] = [[CST_0_]] to [[CST_12_]] step [[CST_3_]] iter_args([[VAR_arg13_:%.+]] = [[VAR_3_]], [[VAR_arg14_:%.+]] = [[VAR_1_]]) -> (index, tensor<1024xf32>) {
 // CHECK-DAG:         [[VAR_reinterpret_cast_0_:%.+]] = memref.reinterpret_cast [[PARAM_1_]] to offset: {{.}}[[VAR_arg13_]]{{.}}, sizes: [1024], strides: [1] : memref<*xf32> to memref<1024xf32, strided<[1], offset: ?>>
 // CHECK-DAG:         [[RES_:%.+]] = memref.alloc() : memref<1024xf32>
 // CHECK:             memref.copy [[VAR_reinterpret_cast_0_]], [[RES_]] : memref<1024xf32, strided<[1], offset: ?>> to memref<1024xf32>
@@ -62,14 +62,12 @@ module {
 // CHECK:               [[VAR_13_1_:%.+]] = arith.addf [[IN_2_]], [[IN_3_]] : f32
 // CHECK:               linalg.yield [[VAR_13_1_]] : f32
 // CHECK:             } -> tensor<1024xf32>
-// CHECK-DAG:         [[VAR_10_:%.+]] = arith.index_cast [[VAR_arg11_]] : index to i32
 // CHECK-DAG:         [[VAR_11_:%.+]] = arith.addi [[VAR_arg13_]], [[VAR_arg11_]] : index
-// CHECK:             [[VAR_12_:%.+]] = arith.addi [[VAR_arg12_]], [[VAR_10_]] : i32
-// CHECK:             scf.yield [[VAR_12_]], [[VAR_11_]], [[VAR_9_]] : i32, index, tensor<1024xf32>
+// CHECK:             scf.yield [[VAR_11_]], [[VAR_9_]] : index, tensor<1024xf32>
 // CHECK:           }
 // CHECK:           [[VAR_5_:%.+]] = arith.muli [[PARAM_8_]], [[PARAM_3_]] : i32
 // CHECK:           [[VAR_6_:%.+]] = arith.index_cast [[VAR_5_]] : i32 to index
 // CHECK:           [[VAR_reinterpret_cast_:%.+]] = memref.reinterpret_cast [[PARAM_0_]] to offset: {{.}}[[VAR_6_]]{{.}}, sizes: [1024], strides: [1] : memref<*xf32> to memref<1024xf32, strided<[1], offset: ?>>
-// CHECK:           bufferization.materialize_in_destination [[VAR_4_]]#2 in writable [[VAR_reinterpret_cast_]] : (tensor<1024xf32>, memref<1024xf32, strided<[1], offset: ?>>) -> ()
+// CHECK:           bufferization.materialize_in_destination [[VAR_4_]]#1 in writable [[VAR_reinterpret_cast_]] : (tensor<1024xf32>, memref<1024xf32, strided<[1], offset: ?>>) -> ()
 // CHECK:           return
 // CHECK:         }

test/Conversion/StructuredToMemref/addptr_scalar_for_2d.mlir

Lines changed: 3 additions & 5 deletions
@@ -1,4 +1,4 @@
-// RUN: triton-shared-opt --split-input-file --triton-to-linalg-experimental %s | FileCheck %s
+// RUN: triton-shared-opt --triton-to-linalg-experimental %s | FileCheck %s
 module {
   tt.func @kernel (%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg2: i32, %arg3: i32, %arg4: i32) {
     %0 = tt.get_program_id x : i32
@@ -67,7 +67,7 @@ module {
 // CHECK-DAG:       [[VAR_1_:%.+]] = linalg.fill ins([[CST_0_dot_000000_]] : f32) outs([[VAR_0_]] : tensor<128x128xf32>) -> tensor<128x128xf32>
 // CHECK-DAG:       [[VAR_2_:%.+]] = arith.muli [[PARAM_8_]], [[PARAM_2_]] : i32
 // CHECK:           [[VAR_3_:%.+]] = arith.index_cast [[VAR_2_]] : i32 to index
-// CHECK-DAG:       [[VAR_4_:%.+]]:3 = scf.for [[VAR_arg11_:%.+]] = [[CST_0_]] to [[CST_12_]] step [[CST_3_]] iter_args([[VAR_arg12_:%.+]] = [[VAR_1_]], [[VAR_arg13_:%.+]] = [[VAR_2_]], [[VAR_arg14_:%.+]] = [[VAR_3_]]) -> (tensor<128x128xf32>, i32, index) {
+// CHECK-DAG:       [[VAR_4_:%.+]]:2 = scf.for [[VAR_arg11_:%.+]] = [[CST_0_]] to [[CST_12_]] step [[CST_3_]] iter_args([[VAR_arg12_:%.+]] = [[VAR_1_]], [[VAR_arg14_:%.+]] = [[VAR_3_]]) -> (tensor<128x128xf32>, index) {
 // CHECK-DAG:         [[VAR_8_:%.+]] = arith.addi [[VAR_arg14_]], [[CST_128_]] : index
 // CHECK-NOT:         separator of consecutive DAGs
 // CHECK-DAG:         [[VAR_reinterpret_cast_0_:%.+]] = memref.reinterpret_cast [[PARAM_1_]] to offset: {{.}}[[VAR_8_]]{{.}}, sizes: [128, 128], strides: [1, 1] : memref<*xf32> to memref<128x128xf32, strided<[1, 1], offset: ?>>
@@ -84,10 +84,8 @@ module {
 // CHECK:               [[VAR_15_1_:%.+]] = arith.addf [[IN_2_]], [[IN_3_]] : f32
 // CHECK:               linalg.yield [[VAR_15_1_]] : f32
 // CHECK:             } -> tensor<128x128xf32>
-// CHECK-DAG:         [[VAR_12_:%.+]] = arith.index_cast [[VAR_arg11_]] : index to i32
 // CHECK-DAG:         [[VAR_13_:%.+]] = arith.addi [[VAR_arg14_]], [[VAR_arg11_]] : index
-// CHECK:             [[VAR_14_:%.+]] = arith.addi [[VAR_arg13_]], [[VAR_12_]] : i32
-// CHECK:             scf.yield [[VAR_11_]], [[VAR_14_]], [[VAR_13_]] : tensor<128x128xf32>, i32, index
+// CHECK:             scf.yield [[VAR_11_]], [[VAR_13_]] : tensor<128x128xf32>, index
 // CHECK:           }
 // CHECK:           [[VAR_5_:%.+]] = arith.muli [[PARAM_8_]], [[PARAM_3_]] : i32
 // CHECK:           [[VAR_6_:%.+]] = arith.index_cast [[VAR_5_]] : i32 to index

test/Conversion/StructuredToMemref/nested_loops.mlir

Lines changed: 1 addition & 3 deletions
@@ -182,6 +182,4 @@ module {
 // CHECK-NOT: tt.addptr
 // CHECK-NOT: tt.load
 // CHECK-NOT: tt.store
-
-// CHECK-COUNT-20: memref.reinterpret_cast %arg{{[0-9]+}}
-// CHECK-NOT: memref.reinterpret_cast %arg{{[0-9]+}}
+// CHECK-NOT: unrealized_conversion_cast

test/Conversion/StructuredToMemref/ridiculously_nested_loops.mlir

Lines changed: 1 addition & 3 deletions
@@ -159,6 +159,4 @@ module {
 // CHECK-NOT: tt.addptr
 // CHECK-NOT: tt.load
 // CHECK-NOT: tt.store
-
-// CHECK-COUNT-43: memref.reinterpret_cast %arg{{[0-9]+}}
-// CHECK-NOT: memref.reinterpret_cast %arg{{[0-9]+}}
+// CHECK-NOT: unrealized_conversion_cast

test/Conversion/StructuredToMemref/scalar_store_loop_iterargs.mlir

Lines changed: 2 additions & 6 deletions
@@ -25,18 +25,14 @@ module {
 // CHECK-DAG:       [[CST_5_:%.+]] = arith.constant 5 : i32
 // CHECK-DAG:       [[CST_1_:%.+]] = arith.constant 1 : i32
 // CHECK-DAG:       [[CST_0_:%.+]] = arith.constant 0 : i32
-// CHECK-DAG:       [[VAR_0_:%.+]] = arith.index_cast [[PARAM_7_]] : i32 to index
 // CHECK-DAG:       [[VAR_1_:%.+]] = arith.sitofp [[PARAM_7_]] : i32 to f32
 // CHECK-NOT:       separator of consecutive DAGs
-// CHECK-DAG:       [[VAR_2_:%.+]]:3 = scf.for [[VAR_arg16_:%.+]] = [[CST_0_]] to [[CST_5_]] step [[CST_1_]] iter_args([[VAR_arg17_:%.+]] = [[PARAM_7_]], [[VAR_arg18_:%.+]] = [[VAR_0_]], [[VAR_arg19_:%.+]] = [[VAR_0_]]) -> (i32, index, index) : i32 {
+// CHECK-DAG:       [[VAR_2_:%.+]] = scf.for [[VAR_arg16_:%.+]] = [[CST_0_]] to [[CST_5_]] step [[CST_1_]] iter_args([[VAR_arg17_:%.+]] = [[PARAM_7_]]) -> (i32) : i32 {
 // CHECK-DAG:         [[VAR_3_:%.+]] = arith.index_cast [[VAR_arg17_]] : i32 to index
 // CHECK:             [[VAR_reinterpret_cast_:%.+]] = memref.reinterpret_cast [[PARAM_1_]] to offset: {{.}}[[VAR_3_]]{{.}}, sizes: [1], strides: [1] : memref<*xf32> to memref<1xf32, strided<[1], offset: ?>>
 // CHECK:             memref.store [[VAR_1_]], [[VAR_reinterpret_cast_]][%[[C0]]] : memref<1xf32, strided<[1], offset: ?>>
-// CHECK:             [[VAR_4_:%.+]] = arith.index_cast [[VAR_arg16_]] : i32 to index
-// CHECK-DAG:         [[VAR_5_:%.+]] = arith.addi [[VAR_arg18_]], [[VAR_4_]] : index
 // CHECK-DAG:         [[VAR_6_:%.+]] = arith.addi [[VAR_arg17_]], [[VAR_arg16_]] : i32
-// CHECK-DAG:         [[VAR_7_:%.+]] = arith.addi [[VAR_arg19_]], [[VAR_4_]] : index
-// CHECK:             scf.yield [[VAR_6_]], [[VAR_5_]], [[VAR_7_]] : i32, index, index
+// CHECK:             scf.yield [[VAR_6_]] : i32
 // CHECK:           }
 // CHECK:           return
 // CHECK:         }

test/Conversion/StructuredToMemref/scalar_store_nested_loop.mlir

Lines changed: 5 additions & 9 deletions
@@ -24,26 +24,22 @@ module {
 // CHECK-LABEL:  func.func @reduce_kernel_2d_0d
 // CHECK-SAME:   ([[PARAM_0_:%.+]]: memref<*xf32> {tt.divisibility = 16 : i32}, [[PARAM_1_:%.+]]: i32, [[PARAM_2_:%.+]]: i32, [[PARAM_3_:%.+]]: i32, [[PARAM_4_:%.+]]: i32, [[PARAM_5_:%.+]]: i32, [[PARAM_6_:%.+]]: i32, [[PARAM_7_:%.+]]: i32, [[PARAM_8_:%.+]]: i32, [[PARAM_9_:%.+]]: i32, [[PARAM_10_:%.+]]: i32, [[PARAM_11_:%.+]]: i32, [[PARAM_12_:%.+]]: i32) {
 // CHECK-DAG:       %[[C0:.*]] = arith.constant 0 : index
-// CHECK-DAG:       [[CST_1_:%.+]] = arith.constant 1 : index
 // CHECK-DAG:       [[CST_1_1_:%.+]] = arith.constant 1 : i32
 // CHECK-DAG:       [[CST_8_:%.+]] = arith.constant 8 : i32
 // CHECK-DAG:       [[CST_2_:%.+]] = arith.constant 2 : i32
 // CHECK-DAG:       [[CST_0_:%.+]] = arith.constant 0 : i32
-// CHECK-DAG:       [[VAR_0_:%.+]] = arith.index_cast [[PARAM_4_]] : i32 to index
-// CHECK-NOT:       separator of consecutive DAGs
-// CHECK-DAG:       [[VAR_1_:%.+]]:2 = scf.for [[VAR_arg13_:%.+]] = [[CST_0_]] to [[CST_8_]] step [[CST_1_1_]] iter_args([[VAR_arg14_:%.+]] = [[PARAM_4_]], [[VAR_arg15_:%.+]] = [[VAR_0_]]) -> (i32, index) : i32 {
-// CHECK-DAG:         [[VAR_2_:%.+]]:2 = scf.for [[VAR_arg16_:%.+]] = [[CST_0_]] to [[CST_2_]] step [[CST_1_1_]] iter_args([[VAR_arg17_:%.+]] = [[VAR_arg14_]], [[VAR_arg18_:%.+]] = [[VAR_arg15_]]) -> (i32, index) : i32 {
+// CHECK-DAG:       [[VAR_1_:%.+]] = scf.for [[VAR_arg13_:%.+]] = [[CST_0_]] to [[CST_8_]] step [[CST_1_1_]] iter_args([[VAR_arg14_:%.+]] = [[PARAM_4_]]) -> (i32) : i32 {
+// CHECK-DAG:         [[VAR_2_:%.+]] = scf.for [[VAR_arg16_:%.+]] = [[CST_0_]] to [[CST_2_]] step [[CST_1_1_]] iter_args([[VAR_arg17_:%.+]] = [[VAR_arg14_]]) -> (i32) : i32 {
 // CHECK-DAG:           [[VAR_3_:%.+]] = arith.muli [[VAR_arg13_]], [[VAR_arg16_]] : i32
 // CHECK-NOT:           separator of consecutive DAGs
 // CHECK-DAG:           [[VAR_4_:%.+]] = arith.sitofp [[VAR_3_]] : i32 to f32
 // CHECK-DAG:           [[VAR_5_:%.+]] = arith.index_cast [[VAR_arg17_]] : i32 to index
 // CHECK:               [[VAR_reinterpret_cast_:%.+]] = memref.reinterpret_cast [[PARAM_0_]] to offset: {{.}}[[VAR_5_]]{{.}}, sizes: [1], strides: [1] : memref<*xf32> to memref<1xf32, strided<[1], offset: ?>>
 // CHECK:               memref.store [[VAR_4_]], [[VAR_reinterpret_cast_]][%[[C0]]] : memref<1xf32, strided<[1], offset: ?>>
-// CHECK-DAG:           [[VAR_6_:%.+]] = arith.addi [[VAR_arg18_]], [[CST_1_]] : index
-// CHECK-DAG:           [[VAR_7_:%.+]] = arith.addi [[VAR_arg17_]], [[CST_1_1_]] : i32
-// CHECK:               scf.yield [[VAR_7_]], [[VAR_6_]] : i32, index
+// CHECK:               [[VAR_7_:%.+]] = arith.addi [[VAR_arg17_]], [[CST_1_1_]] : i32
+// CHECK:               scf.yield [[VAR_7_]] : i32
 // CHECK:             }
-// CHECK:             scf.yield [[VAR_2_]]#0, [[VAR_2_]]#1 : i32, index
+// CHECK:             scf.yield [[VAR_2_]] : i32
 // CHECK:           }
 // CHECK:           return
 // CHECK:         }
