diff --git a/lib/Conversion/UnstructuredToMemref/UnstructuredToMemrefPass.cpp b/lib/Conversion/UnstructuredToMemref/UnstructuredToMemrefPass.cpp index 2fed7843..22e6ffc8 100644 --- a/lib/Conversion/UnstructuredToMemref/UnstructuredToMemrefPass.cpp +++ b/lib/Conversion/UnstructuredToMemref/UnstructuredToMemrefPass.cpp @@ -52,8 +52,7 @@ class PtrToUnrankedMemrefConverter : public TypeConverter { }); addTargetMaterialization([&](OpBuilder &builder, UnrankedMemRefType resultType, - ValueRange inputs, - Location loc) -> Value { + ValueRange inputs, Location loc) -> Value { return builder.create(loc, resultType, inputs) .getResult(0); }); @@ -158,7 +157,7 @@ struct ScalarStoreConverter : public OpConversionPattern { } }; -// Lowering an unstructured load op (gather) into a linalg.generic op +// Lowering an unstructured load op (gather) into a linalg.generic op. struct GatherConverter : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; @@ -171,28 +170,29 @@ struct GatherConverter : public OpConversionPattern { LogicalResult matchAndRewrite(tts::GatherOp gatherOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - auto loc = gatherOp->getLoc(); auto ptr = adaptor.getPtr(); auto offsetTensor = adaptor.getOffset(); auto offsetType = dyn_cast(offsetTensor.getType()); - // This must be a scalar load, skip processing + // This must be a scalar load, skip processing. if (!offsetType) { return failure(); } - auto loadResultType = + auto resultType = dyn_cast(gatherOp.getResult().getType()); // Treat the base pointer (memref) as 1D because the offsets are all // relative to a single base pointer (already collapsed). 
- auto baseMemref = rewriter.create( - loc, - MemRefType::get({ShapedType::kDynamic}, - loadResultType.getElementType()), - ptr); + auto baseMemref = rewriter + .create( + loc, + MemRefType::get({ShapedType::kDynamic}, + resultType.getElementType()), + ptr) + .getResult(); auto baseTensor = rewriter @@ -200,89 +200,79 @@ struct GatherConverter : public OpConversionPattern { loc, RankedTensorType::get( SmallVector(1, ShapedType::kDynamic), - loadResultType.getElementType()), + resultType.getElementType()), baseMemref, true /* restrict */, false /* writable */) .getResult(); // The linalg.generic op should have the following inputs: - // - the offset tensor - // - an optional mask tensor if the load op contains mask + // - the offset tensor. + // - an optional mask tensor if the gather op contains mask. SmallVector inputs{offsetTensor}; if (gatherOp.getMask()) { inputs.push_back(gatherOp.getMask()); } - auto emptyTensor = - rewriter - .create(loc, loadResultType.getShape(), - loadResultType.getElementType()) - .getResult(); + auto emptyTensor = rewriter + .create(loc, resultType.getShape(), + resultType.getElementType()) + .getResult(); - // Affine maps for the inputs and output - // If no mask is used, 2 affine maps are generated; one for the input offset - // tensor, the other for the output tensor. - // If mask is used, the first 2 maps are for the offset and mask tensors - // while the last map is for the output tensor. + // Affine maps for the inputs and one additional output. SmallVector affineMaps( - gatherOp.getMask() ? 3 : 2, - rewriter.getMultiDimIdentityMap(loadResultType.getRank())); + inputs.size() + 1, + rewriter.getMultiDimIdentityMap(resultType.getRank())); + + // All iterator types are parallel. 
+ SmallVector iteratorTypes( + resultType.getRank(), utils::IteratorType::parallel); auto genericOp = rewriter.create( - loc, SmallVector({loadResultType}), inputs, - ValueRange{emptyTensor}, affineMaps, - SmallVector(loadResultType.getRank(), - utils::IteratorType::parallel), - [&](OpBuilder &b, Location loc, ValueRange args) { - auto getValueAtIndex = [baseTensor](Value indexValue, Location loc, - OpBuilder &b) -> Value { + loc, TypeRange{resultType}, inputs, ValueRange{emptyTensor}, affineMaps, + iteratorTypes, [&](OpBuilder &b, Location loc, ValueRange args) { + auto getValueAtIndex = [baseTensor](OpBuilder &b, Location loc, + Value index) -> Value { Value index0 = - b.create(loc, b.getIndexType(), indexValue); + b.create(loc, b.getIndexType(), index); return b.create(loc, baseTensor, ValueRange{index0}); }; + auto offset = args[0]; + if (!gatherOp.getMask()) { // If there is no mask, simply extract the current element from the // base tensor and use it as the yield value. - auto loadValue = getValueAtIndex(args[0], loc, rewriter); - rewriter.create(loc, loadValue); + auto loadValue = getValueAtIndex(b, loc, offset); + b.create(loc, loadValue); } else { // If the mask value is truthy, the current element is loaded from // the base tensor using its offset. Otherwise, if `other` is // present, yield `other`. If `other` is not present, a default // value of 0 is used. auto mask = args[1]; - auto ifOp = rewriter.create( + auto ifOp = b.create( loc, mask, [&](OpBuilder &b, Location loc) { - // Truthy case, load from the index - auto loadValue = getValueAtIndex(args[0], loc, b); - b.create(loc, loadValue); + // Truthy case, load from the index. + auto value = getValueAtIndex(b, loc, offset); + b.create(loc, value); }, [&](OpBuilder &b, Location loc) { - // Falsy case, yield `other` or 0 as the default value + // Falsy case, yield `other` or 0 as the default value. 
if (gatherOp.getOther()) { b.create(loc, gatherOp.getOther()); } else { - auto elemType = baseTensor.getType().getElementType(); - Value extract; - if (isa(elemType)) { - extract = rewriter.create( - loc, b.getIntegerAttr(elemType, 0)); - } else if (isa(elemType)) { - extract = rewriter.create( - loc, b.getFloatAttr(elemType, 0)); - } else { - elemType.dump(); - llvm_unreachable("unexpected type"); - } + auto elemType = resultType.getElementType(); + auto zeroAttr = b.getZeroAttr(elemType); + assert(zeroAttr && "unexpected element type"); + Value extract = b.create(loc, zeroAttr); b.create(loc, extract); } }); - rewriter.create(loc, ifOp->getResult(0)); + b.create(loc, ifOp->getResult(0)); } }); @@ -292,7 +282,7 @@ struct GatherConverter : public OpConversionPattern { } }; -// Lowering an unstructured store op (scatter) into an affine loop nest +// Lowering an unstructured store op (scatter) into a linalg.generic op. struct ScatterConverter : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; @@ -309,57 +299,81 @@ struct ScatterConverter : public OpConversionPattern { auto ptr = adaptor.getPtr(); auto offsetTensor = adaptor.getOffset(); + auto valueTensor = adaptor.getValue(); auto offsetType = dyn_cast(offsetTensor.getType()); - // This must be a scalar store, skip processing + // This must be a scalar store, skip processing. 
if (!offsetType) { return failure(); } - auto resultType = - dyn_cast(scatterOp.getValue().getType()); + auto valueType = dyn_cast(scatterOp.getValue().getType()); - auto storeMemref = rewriter.create( - loc, - MemRefType::get({ShapedType::kDynamic}, resultType.getElementType()), - ptr); - - auto ip = rewriter.saveInsertionPoint(); - - SmallVector ivs; - for (auto dim : resultType.getShape()) { - auto ub = - rewriter.create(loc, rewriter.getIndexAttr(dim)); - auto forOp = rewriter.create(loc, 0, dim); - ivs.push_back(forOp.getInductionVar()); - rewriter.setInsertionPointToStart(forOp.getBody()); - } + // Treat the base pointer (memref) as 1D because the offsets are all + // relative to a single base pointer (already collapsed). + auto baseMemref = + rewriter + .create(loc, + MemRefType::get({ShapedType::kDynamic}, + valueType.getElementType()), + ptr) + .getResult(); + + // The linalg.generic op should have the following inputs: + // - the offset tensor. + // - the value tensor. + // - an optional mask tensor if the scatter op contains mask. + SmallVector inputs{offsetTensor, valueTensor}; if (scatterOp.getMask()) { - // Mask case, only store the value if the mask value at `ivs` is truthy - auto maskValue = - rewriter.create(loc, scatterOp.getMask(), ivs); + inputs.push_back(scatterOp.getMask()); + } - auto ifOp = rewriter.create(loc, maskValue, - false /* withElseRegion */); + // Affine maps for the inputs. + SmallVector affineMaps( + inputs.size(), rewriter.getMultiDimIdentityMap(valueType.getRank())); - rewriter.setInsertionPointToStart( - &ifOp.getThenRegion().getBlocks().front()); - } + // All iterator types are parallel. + SmallVector iteratorTypes( + valueType.getRank(), utils::IteratorType::parallel); - // Generate ops to store the value at each index. Note that with masking, - // these ops are created in the `if` block generated above. 
- auto offsetValue = -        rewriter.create(loc, offsetTensor, ivs); -    auto storeValue = -        rewriter.create(loc, scatterOp.getValue(), ivs); -    Value storeIndex = rewriter.create( -        loc, rewriter.getIndexType(), offsetValue); -    rewriter.create(loc, storeValue, storeMemref, storeIndex); + rewriter.setInsertionPoint(scatterOp); + + auto genericOp = rewriter.create( + loc, TypeRange{}, inputs, ValueRange{}, affineMaps, iteratorTypes, + [&](OpBuilder &b, Location loc, ValueRange args) { + auto storeValueAtIndex = [baseMemref](OpBuilder &b, Location loc, + Value index, Value value) { + Value index0 = + b.create(loc, b.getIndexType(), index); + + b.create(loc, value, baseMemref, + ValueRange{index0}); + }; + + auto offset = args[0]; + auto value = args[1]; + + if (!scatterOp.getMask()) { + // If there is no mask, simply insert the current value to the + // base memref using its offset. + storeValueAtIndex(b, loc, offset, value); + } else { + // If the mask value is truthy, insert the current value to the + // base memref using its offset. Otherwise, noop. 
+ auto mask = args[2]; + auto ifOp = + b.create(loc, mask, [&](OpBuilder &b, Location loc) { + storeValueAtIndex(b, loc, offset, value); + b.create(loc); + }); + } + + b.create(loc); + }); - // Finalize rewriter.eraseOp(scatterOp); - rewriter.restoreInsertionPoint(ip); + return success(); } }; diff --git a/test/Conversion/StructuredToMemref/convert_1d_elemwise_arith_binary.mlir b/test/Conversion/StructuredToMemref/convert_1d_elemwise_arith_binary.mlir index 1d96e320..dedd80ef 100644 --- a/test/Conversion/StructuredToMemref/convert_1d_elemwise_arith_binary.mlir +++ b/test/Conversion/StructuredToMemref/convert_1d_elemwise_arith_binary.mlir @@ -32,7 +32,9 @@ module { // CHECK-DAG: [[MAP_0_:#.+]] = affine_map<(d0) -> (d0)> // CHECK-LABEL: func.func @kernel // CHECK-SAME: ([[PARAM_0_:%.+]]: memref<*xf32>, [[PARAM_1_:%.+]]: memref<*xf32>, [[PARAM_2_:%.+]]: memref<*xf32>, [[PARAM_3_:%.+]]: i32, [[PARAM_4_:%.+]]: i32, [[PARAM_5_:%.+]]: i32, [[PARAM_6_:%.+]]: i32, [[PARAM_7_:%.+]]: i32, [[PARAM_8_:%.+]]: i32) { -// CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0 : index +// CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0 : i32 +// CHECK-DAG: [[VAR_empty_offsets_:%.+]] = tensor.empty() : tensor<1024xi32> +// CHECK-DAG: [[VAR_zero_offsets_:%.+]] = linalg.fill ins([[CST_0_]] : i32) outs([[VAR_empty_offsets_]] : tensor<1024xi32>) -> tensor<1024xi32> // CHECK-DAG: [[VAR_reinterpret_cast_:%.+]] = memref.reinterpret_cast [[PARAM_0_]] to offset: [0], sizes: [1024], strides: [1] : memref<*xf32> to memref<1024xf32, strided<[1]>> // CHECK-DAG: [[VAR_reinterpret_cast_0_:%.+]] = memref.reinterpret_cast [[PARAM_1_]] to offset: [0], sizes: [1024], strides: [1] : memref<*xf32> to memref<1024xf32, strided<[1]>> // CHECK-DAG: [[RES_:%.+]] = memref.alloc() : memref<1024xf32> @@ -73,9 +75,11 @@ module { // CHECK: linalg.yield [[VAR_9_5_]] : f32 // CHECK: } -> tensor<1024xf32> // CHECK: [[VAR_cast_:%.+]] = memref.cast [[PARAM_2_]] : memref<*xf32> to memref -// CHECK: affine.for [[I_0_:%.+]] = 0 to 
1024 { -// CHECK: [[VAR_extracted_:%.+]] = tensor.extract [[VAR_8_]]{{.}}[[I_0_]]{{.}} : tensor<1024xf32> -// CHECK: memref.store [[VAR_extracted_]], [[VAR_cast_]]{{.}}[[CST_0_]]{{.}} : memref +// CHECK: linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins([[VAR_zero_offsets_]], [[VAR_8_]] : tensor<1024xi32>, tensor<1024xf32>) { +// CHECK: ^bb0([[IN_18_:%.+]]: i32, [[IN_19_:%.+]]: f32): +// CHECK: [[VAR_10_:%.+]] = arith.index_cast [[IN_18_]] : i32 to index +// CHECK: memref.store [[IN_19_]], [[VAR_cast_]]{{.}}[[VAR_10_]]{{.}} : memref +// CHECK: linalg.yield // CHECK: } // CHECK: return // CHECK: } diff --git a/test/Conversion/StructuredToMemref/convert_1d_elemwise_arith_ternary.mlir b/test/Conversion/StructuredToMemref/convert_1d_elemwise_arith_ternary.mlir index 92393d0f..c22eda2c 100644 --- a/test/Conversion/StructuredToMemref/convert_1d_elemwise_arith_ternary.mlir +++ b/test/Conversion/StructuredToMemref/convert_1d_elemwise_arith_ternary.mlir @@ -31,7 +31,9 @@ module { // CHECK-DAG: [[MAP_0_:#.+]] = affine_map<(d0) -> (d0)> // CHECK-LABEL: func.func @kernel // CHECK-SAME: ([[PARAM_0_:%.+]]: memref<*xi1>, [[PARAM_1_:%.+]]: memref<*xf32>, [[PARAM_2_:%.+]]: memref<*xf32>, [[PARAM_3_:%.+]]: memref<*xf32>, [[PARAM_4_:%.+]]: i32, [[PARAM_5_:%.+]]: i32, [[PARAM_6_:%.+]]: i32, [[PARAM_7_:%.+]]: i32, [[PARAM_8_:%.+]]: i32, [[PARAM_9_:%.+]]: i32) { -// CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0 : index +// CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0 : i32 +// CHECK-DAG: [[VAR_empty_offsets_:%.+]] = tensor.empty() : tensor<1024xi32> +// CHECK-DAG: [[VAR_zero_offsets_:%.+]] = linalg.fill ins([[CST_0_]] : i32) outs([[VAR_empty_offsets_]] : tensor<1024xi32>) -> tensor<1024xi32> // CHECK-DAG: [[VAR_reinterpret_cast_:%.+]] = memref.reinterpret_cast [[PARAM_0_]] to offset: [0], sizes: [1024], strides: [1] : memref<*xi1> to memref<1024xi1, strided<[1]>> // CHECK-DAG: [[VAR_reinterpret_cast_0_:%.+]] = memref.reinterpret_cast [[PARAM_1_]] to offset: 
[0], sizes: [1024], strides: [1] : memref<*xf32> to memref<1024xf32, strided<[1]>> // CHECK-DAG: [[VAR_reinterpret_cast_1_:%.+]] = memref.reinterpret_cast [[PARAM_2_]] to offset: [0], sizes: [1024], strides: [1] : memref<*xf32> to memref<1024xf32, strided<[1]>> @@ -50,9 +52,11 @@ module { // CHECK: linalg.yield [[VAR_4_]] : f32 // CHECK: } -> tensor<1024xf32> // CHECK: [[VAR_cast_:%.+]] = memref.cast [[PARAM_3_]] : memref<*xf32> to memref -// CHECK: affine.for [[I_0_:%.+]] = 0 to 1024 { -// CHECK: [[VAR_extracted_:%.+]] = tensor.extract [[VAR_3_]]{{.}}[[I_0_]]{{.}} : tensor<1024xf32> -// CHECK: memref.store [[VAR_extracted_]], [[VAR_cast_]]{{.}}[[CST_0_]]{{.}} : memref +// CHECK: linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins([[VAR_zero_offsets_]], [[VAR_3_]] : tensor<1024xi32>, tensor<1024xf32>) { +// CHECK: ^bb0([[IN_3_:%.+]]: i32, [[IN_4_:%.+]]: f32): +// CHECK: [[VAR_5_:%.+]] = arith.index_cast [[IN_3_]] : i32 to index +// CHECK: memref.store [[IN_4_]], [[VAR_cast_]]{{.}}[[VAR_5_]]{{.}} : memref +// CHECK: linalg.yield // CHECK: } // CHECK: return // CHECK: } diff --git a/test/Conversion/StructuredToMemref/convert_1d_elemwise_arith_unary.mlir b/test/Conversion/StructuredToMemref/convert_1d_elemwise_arith_unary.mlir index 86a6c0e1..fb74a44d 100644 --- a/test/Conversion/StructuredToMemref/convert_1d_elemwise_arith_unary.mlir +++ b/test/Conversion/StructuredToMemref/convert_1d_elemwise_arith_unary.mlir @@ -47,7 +47,9 @@ module { // CHECK-DAG: [[MAP_0_:#.+]] = affine_map<(d0) -> (d0)> // CHECK-LABEL: func.func @kernel // CHECK-SAME: ([[PARAM_0_:%.+]]: memref<*xf32>, [[PARAM_1_:%.+]]: memref<*xi32>, [[PARAM_2_:%.+]]: memref<*xf16>, [[PARAM_3_:%.+]]: memref<*xbf16>, [[PARAM_4_:%.+]]: memref<*xf32>, [[PARAM_5_:%.+]]: memref<*xf32>, [[PARAM_6_:%.+]]: memref<*xf32>, [[PARAM_7_:%.+]]: memref<*xf32>, [[PARAM_8_:%.+]]: i32, [[PARAM_9_:%.+]]: i32, [[PARAM_10_:%.+]]: i32, [[PARAM_11_:%.+]]: i32, [[PARAM_12_:%.+]]: i32, [[PARAM_13_:%.+]]: 
i32) { -// CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0 : index +// CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0 : i32 +// CHECK-DAG: [[VAR_empty_offsets_:%.+]] = tensor.empty() : tensor<1024xi32> +// CHECK-DAG: [[VAR_zero_offsets_:%.+]] = linalg.fill ins([[CST_0_]] : i32) outs([[VAR_empty_offsets_]] : tensor<1024xi32>) -> tensor<1024xi32> // CHECK-DAG: [[VAR_reinterpret_cast_:%.+]] = memref.reinterpret_cast [[PARAM_0_]] to offset: [0], sizes: [1024], strides: [1] : memref<*xf32> to memref<1024xf32, strided<[1]>> // CHECK-DAG: [[VAR_reinterpret_cast_0_:%.+]] = memref.reinterpret_cast [[PARAM_1_]] to offset: [0], sizes: [1024], strides: [1] : memref<*xi32> to memref<1024xi32, strided<[1]>> // CHECK-DAG: [[VAR_reinterpret_cast_1_:%.+]] = memref.reinterpret_cast [[PARAM_2_]] to offset: [0], sizes: [1024], strides: [1] : memref<*xf16> to memref<1024xf16, strided<[1]>> @@ -87,30 +89,40 @@ module { // CHECK: [[VAR_10_4_:%.+]] = math.sqrt [[IN_8_]] : f32 // CHECK: linalg.yield [[VAR_10_4_]] : f32 // CHECK: } -> tensor<1024xf32> -// CHECK: [[VAR_cast_:%.+]] = memref.cast [[PARAM_3_]] : memref<*xbf16> to memref -// CHECK: affine.for [[I_0_:%.+]] = 0 to 1024 { -// CHECK: [[VAR_extracted_:%.+]] = tensor.extract [[VAR_4_]]{{.}}[[I_0_]]{{.}} : tensor<1024xbf16> -// CHECK: memref.store [[VAR_extracted_]], [[VAR_cast_]]{{.}}[[CST_0_]]{{.}} : memref +// CHECK: [[VAR_cast_3_:%.+]] = memref.cast [[PARAM_3_]] : memref<*xbf16> to memref +// CHECK: linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins([[VAR_zero_offsets_]], [[VAR_4_]] : tensor<1024xi32>, tensor<1024xbf16>) { +// CHECK: ^bb0([[IN_10_:%.+]]: i32, [[IN_11_:%.+]]: bf16): +// CHECK: [[VAR_11_:%.+]] = arith.index_cast [[IN_10_]] : i32 to index +// CHECK: memref.store [[IN_11_]], [[VAR_cast_3_]]{{.}}[[VAR_11_]]{{.}} : memref +// CHECK: linalg.yield // CHECK: } // CHECK: [[VAR_cast_4_:%.+]] = memref.cast [[PARAM_4_]] : memref<*xf32> to memref -// CHECK: affine.for [[I_1_:%.+]] = 0 to 1024 { -// 
CHECK: [[VAR_extracted_1_:%.+]] = tensor.extract [[VAR_5_]]{{.}}[[I_1_]]{{.}} : tensor<1024xf32> -// CHECK: memref.store [[VAR_extracted_1_]], [[VAR_cast_4_]]{{.}}[[CST_0_]]{{.}} : memref +// CHECK: linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins([[VAR_zero_offsets_]], [[VAR_5_]] : tensor<1024xi32>, tensor<1024xf32>) { +// CHECK: ^bb0([[IN_12_:%.+]]: i32, [[IN_13_:%.+]]: f32): +// CHECK: [[VAR_12_:%.+]] = arith.index_cast [[IN_12_]] : i32 to index +// CHECK: memref.store [[IN_13_]], [[VAR_cast_4_]]{{.}}[[VAR_12_]]{{.}} : memref +// CHECK: linalg.yield // CHECK: } // CHECK: [[VAR_cast_5_:%.+]] = memref.cast [[PARAM_5_]] : memref<*xf32> to memref -// CHECK: affine.for [[I_2_:%.+]] = 0 to 1024 { -// CHECK: [[VAR_extracted_2_:%.+]] = tensor.extract [[VAR_7_]]{{.}}[[I_2_]]{{.}} : tensor<1024xf32> -// CHECK: memref.store [[VAR_extracted_2_]], [[VAR_cast_5_]]{{.}}[[CST_0_]]{{.}} : memref +// CHECK: linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins([[VAR_zero_offsets_]], [[VAR_7_]] : tensor<1024xi32>, tensor<1024xf32>) { +// CHECK: ^bb0([[IN_14_:%.+]]: i32, [[IN_15_:%.+]]: f32): +// CHECK: [[VAR_13_:%.+]] = arith.index_cast [[IN_14_]] : i32 to index +// CHECK: memref.store [[IN_15_]], [[VAR_cast_5_]]{{.}}[[VAR_13_]]{{.}} : memref +// CHECK: linalg.yield // CHECK: } // CHECK: [[VAR_cast_6_:%.+]] = memref.cast [[PARAM_6_]] : memref<*xf32> to memref -// CHECK: affine.for [[I_3_:%.+]] = 0 to 1024 { -// CHECK: [[VAR_extracted_3_:%.+]] = tensor.extract [[VAR_8_]]{{.}}[[I_3_]]{{.}} : tensor<1024xf32> -// CHECK: memref.store [[VAR_extracted_3_]], [[VAR_cast_6_]]{{.}}[[CST_0_]]{{.}} : memref +// CHECK: linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins([[VAR_zero_offsets_]], [[VAR_8_]] : tensor<1024xi32>, tensor<1024xf32>) { +// CHECK: ^bb0([[IN_16_:%.+]]: i32, [[IN_17_:%.+]]: f32): +// CHECK: [[VAR_14_:%.+]] = arith.index_cast [[IN_16_]] : i32 to index +// CHECK: memref.store [[IN_17_]], 
[[VAR_cast_6_]]{{.}}[[VAR_14_]]{{.}} : memref +// CHECK: linalg.yield // CHECK: } // CHECK: [[VAR_cast_7_:%.+]] = memref.cast [[PARAM_7_]] : memref<*xf32> to memref -// CHECK: affine.for [[I_4_:%.+]] = 0 to 1024 { -// CHECK: [[VAR_extracted_4_:%.+]] = tensor.extract [[VAR_9_]]{{.}}[[I_4_]]{{.}} : tensor<1024xf32> -// CHECK: memref.store [[VAR_extracted_4_]], [[VAR_cast_7_]]{{.}}[[CST_0_]]{{.}} : memref +// CHECK: linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins([[VAR_zero_offsets_]], [[VAR_9_]] : tensor<1024xi32>, tensor<1024xf32>) { +// CHECK: ^bb0([[IN_18_:%.+]]: i32, [[IN_19_:%.+]]: f32): +// CHECK: [[VAR_15_:%.+]] = arith.index_cast [[IN_18_]] : i32 to index +// CHECK: memref.store [[IN_19_]], [[VAR_cast_7_]]{{.}}[[VAR_15_]]{{.}} : memref +// CHECK: linalg.yield // CHECK: } // CHECK: return // CHECK: } diff --git a/test/Conversion/StructuredToMemref/convert_2d_elemwise_arith_binary.mlir b/test/Conversion/StructuredToMemref/convert_2d_elemwise_arith_binary.mlir index f20d0ba3..1365a569 100644 --- a/test/Conversion/StructuredToMemref/convert_2d_elemwise_arith_binary.mlir +++ b/test/Conversion/StructuredToMemref/convert_2d_elemwise_arith_binary.mlir @@ -36,7 +36,9 @@ module { // CHECK-DAG: [[MAP_0_:#.+]] = affine_map<(d0, d1) -> (d0, d1)> // CHECK-LABEL: func.func @kernel // CHECK-SAME: ([[PARAM_0_:%.+]]: memref<*xf32>, [[PARAM_1_:%.+]]: memref<*xf32>, [[PARAM_2_:%.+]]: memref<*xf32>, [[PARAM_3_:%.+]]: memref<*xf32>, [[PARAM_4_:%.+]]: i32, [[PARAM_5_:%.+]]: i32, [[PARAM_6_:%.+]]: i32, [[PARAM_7_:%.+]]: i32, [[PARAM_8_:%.+]]: i32, [[PARAM_9_:%.+]]: i32) { -// CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0 : index +// CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0 : i32 +// CHECK-DAG: [[VAR_empty_offsets_:%.+]] = tensor.empty() : tensor<128x128xi32> +// CHECK-DAG: [[VAR_zero_offsets_:%.+]] = linalg.fill ins([[CST_0_]] : i32) outs([[VAR_empty_offsets_]] : tensor<128x128xi32>) -> tensor<128x128xi32> // CHECK-DAG: 
[[VAR_reinterpret_cast_:%.+]] = memref.reinterpret_cast [[PARAM_0_]] to offset: [0], sizes: [128, 128], strides: [1, 1] : memref<*xf32> to memref<128x128xf32, strided<[1, 1]>> // CHECK-DAG: [[VAR_reinterpret_cast_0_:%.+]] = memref.reinterpret_cast [[PARAM_1_]] to offset: [0], sizes: [128, 128], strides: [1, 1] : memref<*xf32> to memref<128x128xf32, strided<[1, 1]>> // CHECK-DAG: [[RES_:%.+]] = memref.alloc() : memref<128x128xf32> @@ -55,18 +57,19 @@ module { // CHECK: [[VAR_4_1_:%.+]] = arith.subf [[IN_3_]], [[IN_4_]] : f32 // CHECK: linalg.yield [[VAR_4_1_]] : f32 // CHECK: } -> tensor<128x128xf32> -// CHECK: [[VAR_cast_:%.+]] = memref.cast [[PARAM_2_]] : memref<*xf32> to memref -// CHECK: affine.for [[I_0_:%.+]] = 0 to 128 { -// CHECK: affine.for [[I_1_:%.+]] = 0 to 128 { -// CHECK: [[VAR_extracted_:%.+]] = tensor.extract [[VAR_2_]]{{.}}[[I_0_]], [[I_1_]]{{.}} : tensor<128x128xf32> -// CHECK: memref.store [[VAR_extracted_]], [[VAR_cast_]]{{.}}[[CST_0_]]{{.}} : memref -// CHECK: } +// CHECK: [[VAR_cast_2_:%.+]] = memref.cast [[PARAM_2_]] : memref<*xf32> to memref +// CHECK: linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins([[VAR_zero_offsets_]], [[VAR_2_]] : tensor<128x128xi32>, tensor<128x128xf32>) { +// CHECK: ^bb0([[IN_6_:%.+]]: i32, [[IN_7_:%.+]]: f32): +// CHECK: [[VAR_5_:%.+]] = arith.index_cast [[IN_6_]] : i32 to index +// CHECK: memref.store [[IN_7_]], [[VAR_cast_2_]]{{.}}[[VAR_5_]]{{.}} : memref +// CHECK: linalg.yield // CHECK: } -// CHECK: [[VAR_cast_2_:%.+]] = memref.cast [[PARAM_3_]] : memref<*xf32> to memref -// CHECK: affine.for [[I_2_:%.+]] = 0 to 128 { -// CHECK: affine.for [[I_3_:%.+]] = 0 to 128 { -// CHECK: [[VAR_extracted_1_:%.+]] = tensor.extract [[VAR_3_]]{{.}}[[I_2_]], [[I_3_]]{{.}} : tensor<128x128xf32> -// CHECK: memref.store [[VAR_extracted_1_]], [[VAR_cast_2_]]{{.}}[[CST_0_]]{{.}} : memref -// CHECK: } +// CHECK: [[VAR_cast_3_:%.+]] = memref.cast [[PARAM_3_]] : memref<*xf32> to memref +// 
CHECK: linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins([[VAR_zero_offsets_]], [[VAR_3_]] : tensor<128x128xi32>, tensor<128x128xf32>) { +// CHECK: ^bb0([[IN_8_:%.+]]: i32, [[IN_9_:%.+]]: f32): +// CHECK: [[VAR_6_:%.+]] = arith.index_cast [[IN_8_]] : i32 to index +// CHECK: memref.store [[IN_9_]], [[VAR_cast_3_]]{{.}}[[VAR_6_]]{{.}} : memref +// CHECK: linalg.yield // CHECK: } // CHECK: return +// CHECK: } diff --git a/test/Conversion/StructuredToMemref/convert_2d_elemwise_arith_ternary.mlir b/test/Conversion/StructuredToMemref/convert_2d_elemwise_arith_ternary.mlir index 12df1a9f..ad8b9533 100644 --- a/test/Conversion/StructuredToMemref/convert_2d_elemwise_arith_ternary.mlir +++ b/test/Conversion/StructuredToMemref/convert_2d_elemwise_arith_ternary.mlir @@ -37,7 +37,9 @@ module { // CHECK-DAG: [[MAP_0_:#.+]] = affine_map<(d0, d1) -> (d0, d1)> // CHECK-LABEL: func.func @kernel // CHECK-SAME: ([[PARAM_0_:%.+]]: memref<*xi1>, [[PARAM_1_:%.+]]: memref<*xf32>, [[PARAM_2_:%.+]]: memref<*xf32>, [[PARAM_3_:%.+]]: memref<*xf32>, [[PARAM_4_:%.+]]: i32, [[PARAM_5_:%.+]]: i32, [[PARAM_6_:%.+]]: i32, [[PARAM_7_:%.+]]: i32, [[PARAM_8_:%.+]]: i32, [[PARAM_9_:%.+]]: i32) { -// CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0 : index +// CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0 : i32 +// CHECK-DAG: [[VAR_empty_offsets_:%.+]] = tensor.empty() : tensor<128x128xi32> +// CHECK-DAG: [[VAR_zero_offsets_:%.+]] = linalg.fill ins([[CST_0_]] : i32) outs([[VAR_empty_offsets_]] : tensor<128x128xi32>) -> tensor<128x128xi32> // CHECK-DAG: [[VAR_reinterpret_cast_:%.+]] = memref.reinterpret_cast [[PARAM_0_]] to offset: [0], sizes: [128, 128], strides: [1, 1] : memref<*xi1> to memref<128x128xi1, strided<[1, 1]>> // CHECK-DAG: [[VAR_reinterpret_cast_0_:%.+]] = memref.reinterpret_cast [[PARAM_1_]] to offset: [0], sizes: [128, 128], strides: [1, 1] : memref<*xf32> to memref<128x128xf32, strided<[1, 1]>> // CHECK-DAG: [[VAR_reinterpret_cast_1_:%.+]] = 
memref.reinterpret_cast [[PARAM_2_]] to offset: [0], sizes: [128, 128], strides: [1, 1] : memref<*xf32> to memref<128x128xf32, strided<[1, 1]>> @@ -56,11 +58,11 @@ module { // CHECK: linalg.yield [[VAR_4_]] : f32 // CHECK: } -> tensor<128x128xf32> // CHECK: [[VAR_cast_:%.+]] = memref.cast [[PARAM_3_]] : memref<*xf32> to memref -// CHECK: affine.for [[I_0_:%.+]] = 0 to 128 { -// CHECK: affine.for [[I_1_:%.+]] = 0 to 128 { -// CHECK: [[VAR_extracted_:%.+]] = tensor.extract [[VAR_3_]]{{.}}[[I_0_]], [[I_1_]]{{.}} : tensor<128x128xf32> -// CHECK: memref.store [[VAR_extracted_]], [[VAR_cast_]]{{.}}[[CST_0_]]{{.}} : memref -// CHECK: } +// CHECK: linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins([[VAR_zero_offsets_]], [[VAR_3_]] : tensor<128x128xi32>, tensor<128x128xf32>) { +// CHECK: ^bb0([[IN_4_:%.+]]: i32, [[IN_5_:%.+]]: f32): +// CHECK: [[VAR_5_:%.+]] = arith.index_cast [[IN_4_]] : i32 to index +// CHECK: memref.store [[IN_5_]], [[VAR_cast_]]{{.}}[[VAR_5_]]{{.}} : memref +// CHECK: linalg.yield // CHECK: } // CHECK: return // CHECK: } diff --git a/test/Conversion/StructuredToMemref/convert_2d_elemwise_arith_unary.mlir b/test/Conversion/StructuredToMemref/convert_2d_elemwise_arith_unary.mlir index ae190d00..128c9c60 100644 --- a/test/Conversion/StructuredToMemref/convert_2d_elemwise_arith_unary.mlir +++ b/test/Conversion/StructuredToMemref/convert_2d_elemwise_arith_unary.mlir @@ -53,7 +53,9 @@ module { // CHECK-DAG: [[MAP_0_:#.+]] = affine_map<(d0, d1) -> (d0, d1)> // CHECK-LABEL: func.func @kernel // CHECK-SAME: ([[PARAM_0_:%.+]]: memref<*xf32>, [[PARAM_1_:%.+]]: memref<*xi32>, [[PARAM_2_:%.+]]: memref<*xf16>, [[PARAM_3_:%.+]]: memref<*xbf16>, [[PARAM_4_:%.+]]: memref<*xf32>, [[PARAM_5_:%.+]]: memref<*xf32>, [[PARAM_6_:%.+]]: memref<*xf32>, [[PARAM_7_:%.+]]: memref<*xf32>, [[PARAM_8_:%.+]]: i32, [[PARAM_9_:%.+]]: i32, [[PARAM_10_:%.+]]: i32, [[PARAM_11_:%.+]]: i32, [[PARAM_12_:%.+]]: i32, [[PARAM_13_:%.+]]: i32) { -// 
CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0 : index +// CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0 : i32 +// CHECK-DAG: [[VAR_empty_offsets_:%.+]] = tensor.empty() : tensor<128x128xi32> +// CHECK-DAG: [[VAR_zero_offsets_:%.+]] = linalg.fill ins([[CST_0_]] : i32) outs([[VAR_empty_offsets_]] : tensor<128x128xi32>) -> tensor<128x128xi32> // CHECK-DAG: [[VAR_reinterpret_cast_:%.+]] = memref.reinterpret_cast [[PARAM_0_]] to offset: [0], sizes: [128, 128], strides: [1, 1] : memref<*xf32> to memref<128x128xf32, strided<[1, 1]>> // CHECK-DAG: [[VAR_reinterpret_cast_0_:%.+]] = memref.reinterpret_cast [[PARAM_1_]] to offset: [0], sizes: [128, 128], strides: [1, 1] : memref<*xi32> to memref<128x128xi32, strided<[1, 1]>> // CHECK-DAG: [[VAR_reinterpret_cast_1_:%.+]] = memref.reinterpret_cast [[PARAM_2_]] to offset: [0], sizes: [128, 128], strides: [1, 1] : memref<*xf16> to memref<128x128xf16, strided<[1, 1]>> @@ -93,40 +95,40 @@ module { // CHECK: [[VAR_10_4_:%.+]] = math.sqrt [[IN_8_]] : f32 // CHECK: linalg.yield [[VAR_10_4_]] : f32 // CHECK: } -> tensor<128x128xf32> -// CHECK: [[VAR_cast_:%.+]] = memref.cast [[PARAM_3_]] : memref<*xbf16> to memref -// CHECK: affine.for [[I_0_:%.+]] = 0 to 128 { -// CHECK: affine.for [[I_1_:%.+]] = 0 to 128 { -// CHECK: [[VAR_extracted_:%.+]] = tensor.extract [[VAR_4_]]{{.}}[[I_0_]], [[I_1_]]{{.}} : tensor<128x128xbf16> -// CHECK: memref.store [[VAR_extracted_]], [[VAR_cast_]]{{.}}[[CST_0_]]{{.}} : memref -// CHECK: } +// CHECK: [[VAR_cast_3_:%.+]] = memref.cast [[PARAM_3_]] : memref<*xbf16> to memref +// CHECK: linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins([[VAR_zero_offsets_]], [[VAR_4_]] : tensor<128x128xi32>, tensor<128x128xbf16>) { +// CHECK: ^bb0([[IN_11_:%.+]]: i32, [[IN_12_:%.+]]: bf16): +// CHECK: [[VAR_11_:%.+]] = arith.index_cast [[IN_11_]] : i32 to index +// CHECK: memref.store [[IN_12_]], [[VAR_cast_3_]]{{.}}[[VAR_11_]]{{.}} : memref +// CHECK: linalg.yield // CHECK: } // CHECK: 
[[VAR_cast_4_:%.+]] = memref.cast [[PARAM_4_]] : memref<*xf32> to memref -// CHECK: affine.for [[I_2_:%.+]] = 0 to 128 { -// CHECK: affine.for [[I_3_:%.+]] = 0 to 128 { -// CHECK: [[VAR_extracted_1_:%.+]] = tensor.extract [[VAR_5_]]{{.}}[[I_2_]], [[I_3_]]{{.}} : tensor<128x128xf32> -// CHECK: memref.store [[VAR_extracted_1_]], [[VAR_cast_4_]]{{.}}[[CST_0_]]{{.}} : memref -// CHECK: } +// CHECK: linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins([[VAR_zero_offsets_]], [[VAR_5_]] : tensor<128x128xi32>, tensor<128x128xf32>) { +// CHECK: ^bb0([[IN_13_:%.+]]: i32, [[IN_14_:%.+]]: f32): +// CHECK: [[VAR_12_:%.+]] = arith.index_cast [[IN_13_]] : i32 to index +// CHECK: memref.store [[IN_14_]], [[VAR_cast_4_]]{{.}}[[VAR_12_]]{{.}} : memref +// CHECK: linalg.yield // CHECK: } // CHECK: [[VAR_cast_5_:%.+]] = memref.cast [[PARAM_5_]] : memref<*xf32> to memref -// CHECK: affine.for [[I_4_:%.+]] = 0 to 128 { -// CHECK: affine.for [[I_5_:%.+]] = 0 to 128 { -// CHECK: [[VAR_extracted_2_:%.+]] = tensor.extract [[VAR_7_]]{{.}}[[I_4_]], [[I_5_]]{{.}} : tensor<128x128xf32> -// CHECK: memref.store [[VAR_extracted_2_]], [[VAR_cast_5_]]{{.}}[[CST_0_]]{{.}} : memref -// CHECK: } +// CHECK: linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins([[VAR_zero_offsets_]], [[VAR_7_]] : tensor<128x128xi32>, tensor<128x128xf32>) { +// CHECK: ^bb0([[IN_15_:%.+]]: i32, [[IN_16_:%.+]]: f32): +// CHECK: [[VAR_13_:%.+]] = arith.index_cast [[IN_15_]] : i32 to index +// CHECK: memref.store [[IN_16_]], [[VAR_cast_5_]]{{.}}[[VAR_13_]]{{.}} : memref +// CHECK: linalg.yield // CHECK: } // CHECK: [[VAR_cast_6_:%.+]] = memref.cast [[PARAM_6_]] : memref<*xf32> to memref -// CHECK: affine.for [[I_6_:%.+]] = 0 to 128 { -// CHECK: affine.for [[I_7_:%.+]] = 0 to 128 { -// CHECK: [[VAR_extracted_3_:%.+]] = tensor.extract [[VAR_8_]]{{.}}[[I_6_]], [[I_7_]]{{.}} : tensor<128x128xf32> -// CHECK: memref.store [[VAR_extracted_3_]], 
[[VAR_cast_6_]]{{.}}[[CST_0_]]{{.}} : memref -// CHECK: } +// CHECK: linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins([[VAR_zero_offsets_]], [[VAR_8_]] : tensor<128x128xi32>, tensor<128x128xf32>) { +// CHECK: ^bb0([[IN_17_:%.+]]: i32, [[IN_18_:%.+]]: f32): +// CHECK: [[VAR_14_:%.+]] = arith.index_cast [[IN_17_]] : i32 to index +// CHECK: memref.store [[IN_18_]], [[VAR_cast_6_]]{{.}}[[VAR_14_]]{{.}} : memref +// CHECK: linalg.yield // CHECK: } // CHECK: [[VAR_cast_7_:%.+]] = memref.cast [[PARAM_7_]] : memref<*xf32> to memref -// CHECK: affine.for [[I_8_:%.+]] = 0 to 128 { -// CHECK: affine.for [[I_9_:%.+]] = 0 to 128 { -// CHECK: [[VAR_extracted_4_:%.+]] = tensor.extract [[VAR_9_]]{{.}}[[I_8_]], [[I_9_]]{{.}} : tensor<128x128xf32> -// CHECK: memref.store [[VAR_extracted_4_]], [[VAR_cast_7_]]{{.}}[[CST_0_]]{{.}} : memref -// CHECK: } +// CHECK: linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins([[VAR_zero_offsets_]], [[VAR_9_]] : tensor<128x128xi32>, tensor<128x128xf32>) { +// CHECK: ^bb0([[IN_19_:%.+]]: i32, [[IN_20_:%.+]]: f32): +// CHECK: [[VAR_15_:%.+]] = arith.index_cast [[IN_19_]] : i32 to index +// CHECK: memref.store [[IN_20_]], [[VAR_cast_7_]]{{.}}[[VAR_15_]]{{.}} : memref +// CHECK: linalg.yield // CHECK: } // CHECK: return // CHECK: } diff --git a/test/Conversion/StructuredToMemref/convert_splat_float.mlir b/test/Conversion/StructuredToMemref/convert_splat_float.mlir index da6b5296..1e94ff96 100644 --- a/test/Conversion/StructuredToMemref/convert_splat_float.mlir +++ b/test/Conversion/StructuredToMemref/convert_splat_float.mlir @@ -15,18 +15,32 @@ module { } } -// CHECK-LABEL: func.func @kernel -// CHECK-SAME: ([[PARAM_0_:%.+]]: f32, [[PARAM_1_:%.+]]: bf16, [[PARAM_2_:%.+]]: memref<*xf32>, [[PARAM_3_:%.+]]: memref<*xbf16>, [[PARAM_4_:%.+]]: i32, [[PARAM_5_:%.+]]: i32, [[PARAM_6_:%.+]]: i32, [[PARAM_7_:%.+]]: i32, [[PARAM_8_:%.+]]: i32, [[PARAM_9_:%.+]]: i32) { -// 
CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0 : index -// CHECK-DAG: [[VAR_cast_:%.+]] = memref.cast [[PARAM_2_]] : memref<*xf32> to memref -// CHECK: affine.for [[I_0_:%.+]] = 0 to 1024 { -// CHECK: memref.store [[PARAM_0_]], [[VAR_cast_]]{{.}}[[CST_0_]]{{.}} : memref +// CHECK-DAG: [[MAP_0_:#.+]] = affine_map<(d0) -> (d0)> +// CHECK-DAG: [[MAP_1_:#.+]] = affine_map<(d0, d1) -> (d0, d1)> +// CHECK-LABEL: func.func @kernel +// CHECK-SAME: ([[PARAM_0_:%.+]]: f32, [[PARAM_1_:%.+]]: bf16, [[PARAM_2_:%.+]]: memref<*xf32>, [[PARAM_3_:%.+]]: memref<*xbf16>, [[PARAM_4_:%.+]]: i32, [[PARAM_5_:%.+]]: i32, [[PARAM_6_:%.+]]: i32, [[PARAM_7_:%.+]]: i32, [[PARAM_8_:%.+]]: i32, [[PARAM_9_:%.+]]: i32) { +// CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0 : i32 +// CHECK-DAG: [[VAR_empty_offsets_1d_:%.+]] = tensor.empty() : tensor<1024xi32> +// CHECK-DAG: [[VAR_zero_offsets_1d_:%.+]] = linalg.fill ins([[CST_0_]] : i32) outs([[VAR_empty_offsets_1d_]] : tensor<1024xi32>) -> tensor<1024xi32> +// CHECK-DAG: [[VAR_empty_offsets_2d_:%.+]] = tensor.empty() : tensor<128x256xi32> +// CHECK-DAG: [[VAR_zero_offsets_2d_:%.+]] = linalg.fill ins([[CST_0_]] : i32) outs([[VAR_empty_offsets_2d_]] : tensor<128x256xi32>) -> tensor<128x256xi32> +// CHECK-DAG: [[VAR_0_:%.+]] = tensor.empty() : tensor<1024xf32> +// CHECK-DAG: [[VAR_1_:%.+]] = linalg.fill ins([[PARAM_0_]] : f32) outs([[VAR_0_]] : tensor<1024xf32>) -> tensor<1024xf32> +// CHECK-DAG: [[VAR_2_:%.+]] = tensor.empty() : tensor<128x256xbf16> +// CHECK-DAG: [[VAR_3_:%.+]] = linalg.fill ins([[PARAM_1_]] : bf16) outs([[VAR_2_]] : tensor<128x256xbf16>) -> tensor<128x256xbf16> +// CHECK: [[VAR_cast_2_:%.+]] = memref.cast [[PARAM_2_]] : memref<*xf32> to memref +// CHECK: linalg.generic {indexing_maps = [[[MAP_0_]], [[MAP_0_]]], iterator_types = ["parallel"]} ins([[VAR_zero_offsets_1d_]], [[VAR_1_]] : tensor<1024xi32>, tensor<1024xf32>) { +// CHECK: ^bb0([[IN_0_:%.+]]: i32, [[IN_1_:%.+]]: f32): +// CHECK: [[VAR_4_:%.+]] = arith.index_cast [[IN_0_]] : 
i32 to index +// CHECK: memref.store [[IN_1_]], [[VAR_cast_2_]]{{.}}[[VAR_4_]]{{.}} : memref +// CHECK: linalg.yield // CHECK: } -// CHECK: [[VAR_cast_0_:%.+]] = memref.cast [[PARAM_3_]] : memref<*xbf16> to memref -// CHECK: affine.for [[I_1_:%.+]] = 0 to 128 { -// CHECK: affine.for [[I_2_:%.+]] = 0 to 256 { -// CHECK: memref.store [[PARAM_1_]], [[VAR_cast_0_]]{{.}}[[CST_0_]]{{.}} : memref -// CHECK: } +// CHECK: [[VAR_cast_3_:%.+]] = memref.cast [[PARAM_3_]] : memref<*xbf16> to memref +// CHECK: linalg.generic {indexing_maps = [[[MAP_1_]], [[MAP_1_]]], iterator_types = ["parallel", "parallel"]} ins([[VAR_zero_offsets_2d_]], [[VAR_3_]] : tensor<128x256xi32>, tensor<128x256xbf16>) { +// CHECK: ^bb0([[IN_2_:%.+]]: i32, [[IN_3_:%.+]]: bf16): +// CHECK: [[VAR_5_:%.+]] = arith.index_cast [[IN_2_]] : i32 to index +// CHECK: memref.store [[IN_3_]], [[VAR_cast_3_]]{{.}}[[VAR_5_]]{{.}} : memref +// CHECK: linalg.yield // CHECK: } // CHECK: return // CHECK: } diff --git a/test/Conversion/StructuredToMemref/reducemax_32_256_bf16.mlir b/test/Conversion/StructuredToMemref/reducemax_32_256_bf16.mlir index 3a0a4033..e552d054 100644 --- a/test/Conversion/StructuredToMemref/reducemax_32_256_bf16.mlir +++ b/test/Conversion/StructuredToMemref/reducemax_32_256_bf16.mlir @@ -41,9 +41,12 @@ module { } } +// CHECK-DAG: [[MAP_0_:#.+]] = affine_map<(d0, d1) -> (d0, d1)> // CHECK-LABEL: func.func @kernel // CHECK-SAME: ([[PARAM_0_:%.+]]: memref<*xbf16>, [[PARAM_1_:%.+]]: memref<*xbf16>, [[PARAM_2_:%.+]]: i32, [[PARAM_3_:%.+]]: i32, [[PARAM_4_:%.+]]: i32, [[PARAM_5_:%.+]]: i32, [[PARAM_6_:%.+]]: i32, [[PARAM_7_:%.+]]: i32) { -// CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0 : index +// CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0 : i32 +// CHECK-DAG: [[VAR_empty_offsets_:%.+]] = tensor.empty() : tensor<256x16xi32> +// CHECK-DAG: [[VAR_zero_offsets_:%.+]] = linalg.fill ins([[CST_0_]] : i32) outs([[VAR_empty_offsets_]] : tensor<256x16xi32>) -> tensor<256x16xi32> // CHECK-DAG: [[CST_0_1_:%.+]] = 
arith.constant 0xFF80 : bf16 // CHECK-DAG: [[CST_256_:%.+]] = arith.constant 256 : index // CHECK-NOT: separator of consecutive DAGs @@ -59,11 +62,11 @@ module { // CHECK: linalg.yield [[VAR_3_]] : bf16 // CHECK: } // CHECK: [[VAR_cast_:%.+]] = memref.cast [[PARAM_1_]] : memref<*xbf16> to memref -// CHECK: affine.for [[I_0_:%.+]] = 0 to 256 { -// CHECK: affine.for [[I_1_:%.+]] = 0 to 16 { -// CHECK: [[VAR_extracted_:%.+]] = tensor.extract [[VAR_reduced_]]{{.}}[[I_0_]], [[I_1_]]{{.}} : tensor<256x16xbf16> -// CHECK: memref.store [[VAR_extracted_]], [[VAR_cast_]]{{.}}[[CST_0_]]{{.}} : memref -// CHECK: } +// CHECK: linalg.generic {indexing_maps = [[[MAP_0_]], [[MAP_0_]]], iterator_types = ["parallel", "parallel"]} ins([[VAR_zero_offsets_]], [[VAR_reduced_]] : tensor<256x16xi32>, tensor<256x16xbf16>) { +// CHECK: ^bb0([[IN_0_:%.+]]: i32, [[IN_1_:%.+]]: bf16): +// CHECK: [[VAR_4_:%.+]] = arith.index_cast [[IN_0_]] : i32 to index +// CHECK: memref.store [[IN_1_]], [[VAR_cast_]]{{.}}[[VAR_4_]]{{.}} : memref +// CHECK: linalg.yield // CHECK: } // CHECK: return // CHECK: } diff --git a/test/Conversion/StructuredToMemref/reducesum_middle_dim.mlir b/test/Conversion/StructuredToMemref/reducesum_middle_dim.mlir index 048369f3..810eeda3 100644 --- a/test/Conversion/StructuredToMemref/reducesum_middle_dim.mlir +++ b/test/Conversion/StructuredToMemref/reducesum_middle_dim.mlir @@ -41,9 +41,12 @@ module { } } +// CHECK-DAG: [[MAP_0_:#.+]] = affine_map<(d0, d1) -> (d0, d1)> // CHECK-LABEL: func.func @kernel // CHECK-SAME: ([[PARAM_0_:%.+]]: memref<*xbf16>, [[PARAM_1_:%.+]]: memref<*xbf16>, [[PARAM_2_:%.+]]: memref<*xbf16>, [[PARAM_3_:%.+]]: i32, [[PARAM_4_:%.+]]: i32, [[PARAM_5_:%.+]]: i32, [[PARAM_6_:%.+]]: i32, [[PARAM_7_:%.+]]: i32, [[PARAM_8_:%.+]]: i32) { -// CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0 : index +// CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0 : i32 +// CHECK-DAG: [[VAR_empty_offsets_:%.+]] = tensor.empty() : tensor<32x16xi32> +// CHECK-DAG: 
[[VAR_zero_offsets_:%.+]] = linalg.fill ins([[CST_0_]] : i32) outs([[VAR_empty_offsets_]] : tensor<32x16xi32>) -> tensor<32x16xi32> // CHECK-DAG: [[CST_0_dot_000000_:%.+]] = arith.constant 0.000000e+00 : bf16 // CHECK-DAG: [[CST_256_:%.+]] = arith.constant 256 : index // CHECK-NOT: separator of consecutive DAGs @@ -59,11 +62,11 @@ module { // CHECK: linalg.yield [[VAR_3_]] : bf16 // CHECK: } // CHECK: [[VAR_cast_:%.+]] = memref.cast [[PARAM_2_]] : memref<*xbf16> to memref -// CHECK: affine.for [[I_0_:%.+]] = 0 to 32 { -// CHECK: affine.for [[I_1_:%.+]] = 0 to 16 { -// CHECK: [[VAR_extracted_:%.+]] = tensor.extract [[VAR_reduced_]]{{.}}[[I_0_]], [[I_1_]]{{.}} : tensor<32x16xbf16> -// CHECK: memref.store [[VAR_extracted_]], [[VAR_cast_]]{{.}}[[CST_0_]]{{.}} : memref -// CHECK: } +// CHECK: linalg.generic {indexing_maps = [[[MAP_0_]], [[MAP_0_]]], iterator_types = ["parallel", "parallel"]} ins([[VAR_zero_offsets_]], [[VAR_reduced_]] : tensor<32x16xi32>, tensor<32x16xbf16>) { +// CHECK: ^bb0([[IN_0_:%.+]]: i32, [[IN_1_:%.+]]: bf16): +// CHECK: [[VAR_5_:%.+]] = arith.index_cast [[IN_0_]] : i32 to index +// CHECK: memref.store [[IN_1_]], [[VAR_cast_]]{{.}}[[VAR_5_]]{{.}} : memref +// CHECK: linalg.yield // CHECK: } // CHECK: return // CHECK: } diff --git a/test/Conversion/UnstructuredToMemref/gather_mask_no_other.mlir b/test/Conversion/UnstructuredToMemref/gather_mask_no_other.mlir index c8d5127b..6be58114 100644 --- a/test/Conversion/UnstructuredToMemref/gather_mask_no_other.mlir +++ b/test/Conversion/UnstructuredToMemref/gather_mask_no_other.mlir @@ -29,50 +29,52 @@ module { } } -// CHECK: #map = affine_map<(d0) -> (d0)> -// CHECK: module { -// CHECK: tt.func public @gather_simple_mask_no_other(%arg0: !tt.ptr, %arg1: !tt.ptr) attributes {noinline = false} { -// CHECK: %cst = arith.constant 0.000000e+00 : f32 -// CHECK: %c8_i32 = arith.constant 8 : i32 -// CHECK: %cst_0 = arith.constant dense<4> : tensor<64xi32> -// CHECK: %c16_i32 = arith.constant 16 : i32 -// CHECK: 
%cst_1 = arith.constant dense<64> : tensor<64xi32> -// CHECK: %c2_i32 = arith.constant 2 : i32 -// CHECK: %c1_i32 = arith.constant 1 : i32 -// CHECK: %c0_i32 = arith.constant 0 : i32 -// CHECK: %0 = builtin.unrealized_conversion_cast %arg1 : !tt.ptr to memref<*xf32> -// CHECK: %1 = builtin.unrealized_conversion_cast %arg0 : !tt.ptr to memref<*xf32> -// CHECK: %2 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32> -// CHECK: %3:3 = scf.for %arg2 = %c0_i32 to %c2_i32 step %c1_i32 iter_args(%arg3 = %c8_i32, %arg4 = %2, %arg5 = %2) -> (i32, tensor<64xi32>, tensor<64xi32>) : i32 { -// CHECK: %4 = arith.divsi %arg4, %cst_0 : tensor<64xi32> -// CHECK: %5 = tt.splat %arg3 : i32 -> tensor<64xi32> -// CHECK: %6 = arith.cmpi slt, %4, %5 : tensor<64xi32> -// CHECK: %cast = memref.cast %1 : memref<*xf32> to memref -// CHECK: %7 = bufferization.to_tensor %cast restrict : memref -// CHECK: %8 = tensor.empty() : tensor<64xf32> -// CHECK: %9 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%4, %6 : tensor<64xi32>, tensor<64xi1>) outs(%8 : tensor<64xf32>) { -// CHECK: ^bb0(%in: i32, %in_3: i1, %out: f32): -// CHECK: %13 = scf.if %in_3 -> (f32) { -// CHECK: %14 = arith.index_cast %in : i32 to index -// CHECK: %extracted = tensor.extract %7[%14] : tensor -// CHECK: scf.yield %extracted : f32 -// CHECK: } else { -// CHECK: scf.yield %cst : f32 +// CHECK-DAG: [[MAP_0_:#.+]] = affine_map<(d0) -> (d0)> +// CHECK: tt.func public @gather_simple_mask_no_other([[PARAM_0_:%.+]]: !tt.ptr, [[PARAM_1_:%.+]]: !tt.ptr) attributes {noinline = false} { +// CHECK-DAG: [[CST_0_dot_000000_:%.+]] = arith.constant 0.000000e+00 : f32 +// CHECK-DAG: [[CST_8_:%.+]] = arith.constant 8 : i32 +// CHECK-DAG: [[VAR_cst_:%.+]] = arith.constant dense<4> : tensor<64xi32> +// CHECK-DAG: [[CST_16_:%.+]] = arith.constant 16 : i32 +// CHECK-DAG: [[VAR_cst_0_:%.+]] = arith.constant dense<64> : tensor<64xi32> +// CHECK-DAG: [[CST_2_:%.+]] = arith.constant 2 : i32 +// 
CHECK-DAG: [[CST_1_:%.+]] = arith.constant 1 : i32 +// CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0 : i32 +// CHECK-DAG: [[VAR_0_:%.+]] = builtin.unrealized_conversion_cast [[PARAM_1_]] : !tt.ptr to memref<*xf32> +// CHECK-DAG: [[VAR_1_:%.+]] = builtin.unrealized_conversion_cast [[PARAM_0_]] : !tt.ptr to memref<*xf32> +// CHECK-DAG: [[VAR_2_:%.+]] = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32> +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_3_:%.+]]:3 = scf.for [[VAR_arg2_:%.+]] = [[CST_0_]] to [[CST_2_]] step [[CST_1_]] iter_args([[VAR_arg3_:%.+]] = [[CST_8_]], [[VAR_arg4_:%.+]] = [[VAR_2_]], [[VAR_arg5_:%.+]] = [[VAR_2_]]) -> (i32, tensor<64xi32>, tensor<64xi32>) : i32 { +// CHECK-DAG: [[VAR_4_:%.+]] = arith.divsi [[VAR_arg4_]], [[VAR_cst_]] : tensor<64xi32> +// CHECK-DAG: [[VAR_5_:%.+]] = tt.splat [[VAR_arg3_]] : i32 -> tensor<64xi32> +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_6_:%.+]] = arith.cmpi slt, [[VAR_4_]], [[VAR_5_]] : tensor<64xi32> +// CHECK-DAG: [[VAR_cast_:%.+]] = memref.cast [[VAR_1_]] : memref<*xf32> to memref +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_7_:%.+]] = bufferization.to_tensor [[VAR_cast_]] restrict : memref +// CHECK-DAG: [[VAR_8_:%.+]] = tensor.empty() : tensor<64xf32> +// CHECK: [[VAR_9_:%.+]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins([[VAR_4_]], [[VAR_6_]] : tensor<64xi32>, tensor<64xi1>) outs([[VAR_8_]] : tensor<64xf32>) { +// CHECK: ^bb0([[IN_0_:%.+]]: i32, [[IN_1_:%.+]]: i1, [[IN_2_:%.+]]: f32): +// CHECK-DAG: [[VAR_13_:%.+]] = scf.if [[IN_1_]] -> (f32) { +// CHECK-DAG: [[VAR_14_:%.+]] = arith.index_cast [[IN_0_]] : i32 to index +// CHECK: [[VAR_extracted_:%.+]] = tensor.extract [[VAR_7_]]{{.}}[[VAR_14_]]{{.}} : tensor +// CHECK: scf.yield [[VAR_extracted_]] : f32 +// CHECK: } else { +// CHECK: scf.yield [[CST_0_dot_000000_]] : f32 +// CHECK: } +// CHECK: linalg.yield [[VAR_13_]] : f32 +// CHECK: } -> 
tensor<64xf32> +// CHECK: [[VAR_cast_2_:%.+]] = memref.cast [[VAR_0_]] : memref<*xf32> to memref +// CHECK: linalg.generic {indexing_maps = [[[MAP_0_]], [[MAP_0_]]], iterator_types = ["parallel"]} ins([[VAR_arg5_]], [[VAR_9_]] : tensor<64xi32>, tensor<64xf32>) { +// CHECK: ^bb0([[IN_3_:%.+]]: i32, [[IN_4_:%.+]]: f32): +// CHECK: [[VAR_12_:%.+]] = arith.index_cast [[IN_3_]] : i32 to index +// CHECK: memref.store [[IN_4_]], [[VAR_cast_2_]]{{.}}[[VAR_12_]]{{.}} : memref +// CHECK: linalg.yield +// CHECK: } +// CHECK-DAG: [[VAR_15_:%.+]] = arith.addi [[VAR_arg3_]], [[CST_16_]] : i32 +// CHECK-DAG: [[VAR_16_:%.+]] = arith.addi [[VAR_arg4_]], [[VAR_cst_0_]] : tensor<64xi32> +// CHECK-DAG: [[VAR_17_:%.+]] = arith.addi [[VAR_arg5_]], [[VAR_cst_0_]] : tensor<64xi32> +// CHECK: scf.yield [[VAR_15_]], [[VAR_16_]], [[VAR_17_]] : i32, tensor<64xi32>, tensor<64xi32> +// CHECK: } +// CHECK: tt.return // CHECK: } -// CHECK: linalg.yield %13 : f32 -// CHECK: } -> tensor<64xf32> -// CHECK: %cast_2 = memref.cast %0 : memref<*xf32> to memref -// CHECK: affine.for %arg6 = 0 to 64 { -// CHECK: %extracted = tensor.extract %arg5[%arg6] : tensor<64xi32> -// CHECK: %extracted_3 = tensor.extract %9[%arg6] : tensor<64xf32> -// CHECK: %13 = arith.index_cast %extracted : i32 to index -// CHECK: memref.store %extracted_3, %cast_2[%13] : memref -// CHECK: } -// CHECK: %10 = arith.addi %arg3, %c16_i32 : i32 -// CHECK: %11 = arith.addi %arg4, %cst_1 : tensor<64xi32> -// CHECK: %12 = arith.addi %arg5, %cst_1 : tensor<64xi32> -// CHECK: scf.yield %10, %11, %12 : i32, tensor<64xi32>, tensor<64xi32> -// CHECK: } -// CHECK: tt.return -// CHECK: } -// CHECK: } + diff --git a/test/Conversion/UnstructuredToMemref/gather_mask_with_other.mlir b/test/Conversion/UnstructuredToMemref/gather_mask_with_other.mlir index a6eb5cdf..c947f176 100644 --- a/test/Conversion/UnstructuredToMemref/gather_mask_with_other.mlir +++ b/test/Conversion/UnstructuredToMemref/gather_mask_with_other.mlir @@ -65,16 +65,16 @@ module { 
// CHECK: linalg.yield [[VAR_13_]] : f32 // CHECK: } -> tensor<64xf32> // CHECK: [[VAR_cast_2_:%.+]] = memref.cast [[VAR_0_]] : memref<*xf32> to memref -// CHECK: affine.for [[I_0_:%.+]] = 0 to 64 { -// CHECK-DAG: [[VAR_extracted_1_:%.+]] = tensor.extract [[VAR_arg5_]]{{.}}[[I_0_]]{{.}} : tensor<64xi32> -// CHECK-DAG: [[VAR_extracted_3_:%.+]] = tensor.extract [[VAR_9_]]{{.}}[[I_0_]]{{.}} : tensor<64xf32> -// CHECK: [[VAR_13_1_:%.+]] = arith.index_cast [[VAR_extracted_1_]] : i32 to index -// CHECK: memref.store [[VAR_extracted_3_]], [[VAR_cast_2_]]{{.}}[[VAR_13_1_]]{{.}} : memref +// CHECK: linalg.generic {indexing_maps = [[[MAP_0_]], [[MAP_0_]]], iterator_types = ["parallel"]} ins([[VAR_arg5_]], [[VAR_9_]] : tensor<64xi32>, tensor<64xf32>) { +// CHECK: ^bb0([[IN_3_:%.+]]: i32, [[IN_4_:%.+]]: f32): +// CHECK: [[VAR_12_:%.+]] = arith.index_cast [[IN_3_]] : i32 to index +// CHECK: memref.store [[IN_4_]], [[VAR_cast_2_]]{{.}}[[VAR_12_]]{{.}} : memref +// CHECK: linalg.yield // CHECK: } -// CHECK-DAG: [[VAR_10_:%.+]] = arith.addi [[VAR_arg3_]], [[CST_16_]] : i32 -// CHECK-DAG: [[VAR_11_:%.+]] = arith.addi [[VAR_arg4_]], [[VAR_cst_0_]] : tensor<64xi32> -// CHECK-DAG: [[VAR_12_:%.+]] = arith.addi [[VAR_arg5_]], [[VAR_cst_0_]] : tensor<64xi32> -// CHECK: scf.yield [[VAR_10_]], [[VAR_11_]], [[VAR_12_]] : i32, tensor<64xi32>, tensor<64xi32> +// CHECK-DAG: [[VAR_15_:%.+]] = arith.addi [[VAR_arg3_]], [[CST_16_]] : i32 +// CHECK-DAG: [[VAR_16_:%.+]] = arith.addi [[VAR_arg4_]], [[VAR_cst_0_]] : tensor<64xi32> +// CHECK-DAG: [[VAR_17_:%.+]] = arith.addi [[VAR_arg5_]], [[VAR_cst_0_]] : tensor<64xi32> +// CHECK: scf.yield [[VAR_15_]], [[VAR_16_]], [[VAR_17_]] : i32, tensor<64xi32>, tensor<64xi32> // CHECK: } // CHECK: tt.return // CHECK: } diff --git a/test/Conversion/UnstructuredToMemref/gather_no_mask.mlir b/test/Conversion/UnstructuredToMemref/gather_no_mask.mlir index 32ba5ce3..44120ef5 100644 --- a/test/Conversion/UnstructuredToMemref/gather_no_mask.mlir +++ 
b/test/Conversion/UnstructuredToMemref/gather_no_mask.mlir @@ -30,45 +30,44 @@ module { } } -// CHECK: #map = affine_map<(d0) -> (d0)> -// CHECK: module { -// CHECK: tt.func public @gather_simple_no_mask(%arg0: !tt.ptr, %arg1: !tt.ptr) attributes {noinline = false} { -// CHECK: %cst = arith.constant dense<10> : tensor<64xi32> -// CHECK: %c5_i32 = arith.constant 5 : i32 -// CHECK: %c64_i32 = arith.constant 64 : i32 -// CHECK: %cst_0 = arith.constant dense<64> : tensor<64xi32> -// CHECK: %c2_i32 = arith.constant 2 : i32 -// CHECK: %c1_i32 = arith.constant 1 : i32 -// CHECK: %c0_i32 = arith.constant 0 : i32 -// CHECK: %0 = builtin.unrealized_conversion_cast %arg1 : !tt.ptr to memref<*xf32> -// CHECK: %1 = builtin.unrealized_conversion_cast %arg0 : !tt.ptr to memref<*xf32> -// CHECK: %2 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32> -// CHECK: %3:2 = scf.for %arg2 = %c0_i32 to %c2_i32 step %c1_i32 iter_args(%arg3 = %2, %arg4 = %2) -> (tensor<64xi32>, tensor<64xi32>) : i32 { -// CHECK: %4 = arith.divsi %arg3, %cst : tensor<64xi32> -// CHECK: %5 = arith.addi %arg2, %c5_i32 : i32 -// CHECK: %6 = arith.remsi %5, %c64_i32 : i32 -// CHECK: %7 = tt.splat %6 : i32 -> tensor<64xi32> -// CHECK: %8 = arith.addi %4, %7 : tensor<64xi32> -// CHECK: %cast = memref.cast %1 : memref<*xf32> to memref -// CHECK: %9 = bufferization.to_tensor %cast restrict : memref -// CHECK: %10 = tensor.empty() : tensor<64xf32> -// CHECK: %11 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%8 : tensor<64xi32>) outs(%10 : tensor<64xf32>) { -// CHECK: ^bb0(%in: i32, %out: f32): -// CHECK: %14 = arith.index_cast %in : i32 to index -// CHECK: %extracted = tensor.extract %9[%14] : tensor -// CHECK: linalg.yield %extracted : f32 -// CHECK: } -> tensor<64xf32> -// CHECK: %cast_1 = memref.cast %0 : memref<*xf32> to memref -// CHECK: affine.for %arg5 = 0 to 64 { -// CHECK: %extracted = tensor.extract %arg4[%arg5] : tensor<64xi32> -// CHECK: %extracted_2 = 
tensor.extract %11[%arg5] : tensor<64xf32> -// CHECK: %14 = arith.index_cast %extracted : i32 to index -// CHECK: memref.store %extracted_2, %cast_1[%14] : memref -// CHECK: } -// CHECK: %12 = arith.addi %8, %cst_0 : tensor<64xi32> -// CHECK: %13 = arith.addi %arg4, %cst_0 : tensor<64xi32> -// CHECK: scf.yield %12, %13 : tensor<64xi32>, tensor<64xi32> -// CHECK: } -// CHECK: tt.return -// CHECK: } -// CHECK: } +// CHECK-DAG: [[MAP_0_:#.+]] = affine_map<(d0) -> (d0)> +// CHECK-LABEL: tt.func public @gather_simple_no_mask +// CHECK-SAME: ([[PARAM_0_:%.+]]: !tt.ptr, [[PARAM_1_:%.+]]: !tt.ptr) attributes {noinline = false} { +// CHECK-DAG: [[CST_10s_:%.+]] = arith.constant dense<10> : tensor<64xi32> +// CHECK-DAG: [[CST_5_:%.+]] = arith.constant 5 : i32 +// CHECK-DAG: [[CST_64_:%.+]] = arith.constant 64 : i32 +// CHECK-DAG: [[CST_64s_:%.+]] = arith.constant dense<64> : tensor<64xi32> +// CHECK-DAG: [[CST_2_:%.+]] = arith.constant 2 : i32 +// CHECK-DAG: [[CST_1_:%.+]] = arith.constant 1 : i32 +// CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0 : i32 +// CHECK-DAG: [[VAR_0_:%.+]] = builtin.unrealized_conversion_cast [[PARAM_1_]] : !tt.ptr to memref<*xf32> +// CHECK-DAG: [[VAR_1_:%.+]] = builtin.unrealized_conversion_cast [[PARAM_0_]] : !tt.ptr to memref<*xf32> +// CHECK-DAG: [[VAR_2_:%.+]] = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32> +// CHECK: [[VAR_3_:%.+]]:2 = scf.for [[VAR_4_:%.+]] = [[CST_0_]] to [[CST_2_]] step [[CST_1_]] iter_args([[VAR_5_:%.+]] = [[VAR_2_]], [[VAR_6_:%.+]] = [[VAR_2_]]) -> (tensor<64xi32>, tensor<64xi32>) : i32 { +// CHECK: [[VAR_7_:%.+]] = arith.divsi [[VAR_5_]], [[CST_10s_]] : tensor<64xi32> +// CHECK: [[VAR_8_:%.+]] = arith.addi [[VAR_4_]], [[CST_5_]] : i32 +// CHECK: [[VAR_9_:%.+]] = arith.remsi [[VAR_8_]], [[CST_64_]] : i32 +// CHECK: [[VAR_10_:%.+]] = tt.splat [[VAR_9_]] : i32 -> tensor<64xi32> +// CHECK: [[VAR_11_:%.+]] = arith.addi [[VAR_7_]], [[VAR_10_]] : tensor<64xi32> +// CHECK: [[VAR_12_:%.+]] = memref.cast 
[[VAR_1_]] : memref<*xf32> to memref +// CHECK: [[VAR_13_:%.+]] = bufferization.to_tensor [[VAR_12_]] restrict : memref to tensor +// CHECK: [[VAR_14_:%.+]] = tensor.empty() : tensor<64xf32> +// CHECK: [[VAR_15_:%.+]] = linalg.generic {indexing_maps = [[[MAP_0_]], [[MAP_0_]]], iterator_types = ["parallel"]} ins([[VAR_11_]] : tensor<64xi32>) outs([[VAR_14_]] : tensor<64xf32>) { +// CHECK: ^bb0([[IN_0_:%.+]]: i32, [[IN_1_:%.+]]: f32): +// CHECK: [[VAR_16_:%.+]] = arith.index_cast [[IN_0_]] : i32 to index +// CHECK: [[VAR_17_:%.+]] = tensor.extract [[VAR_13_]]{{.}}[[VAR_16_]]{{.}} : tensor +// CHECK: linalg.yield [[VAR_17_]] : f32 +// CHECK: } -> tensor<64xf32> +// CHECK: [[VAR_18_:%.+]] = memref.cast [[VAR_0_]] : memref<*xf32> to memref +// CHECK: linalg.generic {indexing_maps = [[[MAP_0_]], [[MAP_0_]]], iterator_types = ["parallel"]} ins([[VAR_6_]], [[VAR_15_]] : tensor<64xi32>, tensor<64xf32>) { +// CHECK: ^bb0([[IN_2_:%.+]]: i32, [[IN_3_:%.+]]: f32): +// CHECK: [[VAR_19_:%.+]] = arith.index_cast [[IN_2_]] : i32 to index +// CHECK: memref.store [[IN_3_]], [[VAR_18_]]{{.}}[[VAR_19_]]{{.}} : memref +// CHECK: linalg.yield +// CHECK: } +// CHECK: [[VAR_20_:%.+]] = arith.addi [[VAR_11_]], [[CST_64s_]] : tensor<64xi32> +// CHECK: [[VAR_21_:%.+]] = arith.addi [[VAR_6_]], [[CST_64s_]] : tensor<64xi32> +// CHECK: scf.yield [[VAR_20_]], [[VAR_21_]] : tensor<64xi32>, tensor<64xi32> +// CHECK: } +// CHECK: tt.return +// CHECK: } \ No newline at end of file diff --git a/test/Conversion/UnstructuredToMemref/gather_scatter_all_mask.mlir b/test/Conversion/UnstructuredToMemref/gather_scatter_all_mask.mlir index 864503f0..a4c5f411 100644 --- a/test/Conversion/UnstructuredToMemref/gather_scatter_all_mask.mlir +++ b/test/Conversion/UnstructuredToMemref/gather_scatter_all_mask.mlir @@ -63,14 +63,13 @@ module { // CHECK: linalg.yield [[VAR_13_]] : f32 // CHECK: } -> tensor<4xf32> // CHECK: [[VAR_cast_3_:%.+]] = memref.cast [[VAR_0_]] : memref<*xf32> to memref -// CHECK: affine.for 
[[I_0_:%.+]] = 0 to 4 { -// CHECK: [[VAR_extracted_1_:%.+]] = tensor.extract [[VAR_7_]]{{.}}[[I_0_]]{{.}} : tensor<4xi1> -// CHECK: scf.if [[VAR_extracted_1_]] { -// CHECK-DAG: [[VAR_extracted_4_:%.+]] = tensor.extract [[VAR_6_]]{{.}}[[I_0_]]{{.}} : tensor<4xi32> -// CHECK-DAG: [[VAR_extracted_5_:%.+]] = tensor.extract [[VAR_10_]]{{.}}[[I_0_]]{{.}} : tensor<4xf32> -// CHECK: [[VAR_13_1_:%.+]] = arith.index_cast [[VAR_extracted_4_]] : i32 to index -// CHECK: memref.store [[VAR_extracted_5_]], [[VAR_cast_3_]]{{.}}[[VAR_13_1_]]{{.}} : memref +// CHECK: linalg.generic {indexing_maps = [[[MAP_0_]], [[MAP_0_]], [[MAP_0_]]], iterator_types = ["parallel"]} ins([[VAR_6_]], [[VAR_10_]], [[VAR_7_]] : tensor<4xi32>, tensor<4xf32>, tensor<4xi1>) { +// CHECK: ^bb0([[IN_3_:%.+]]: i32, [[IN_4_:%.+]]: f32, [[IN_5_:%.+]]: i1): +// CHECK: scf.if [[IN_5_]] { +// CHECK: [[VAR_17_:%.+]] = arith.index_cast [[IN_3_]] : i32 to index +// CHECK: memref.store [[IN_4_]], [[VAR_cast_3_]]{{.}}[[VAR_17_]]{{.}} : memref // CHECK: } +// CHECK: linalg.yield // CHECK: } // CHECK-DAG: [[VAR_11_:%.+]] = arith.addi [[VAR_6_]], [[VAR_cst_1_]] : tensor<4xi32> // CHECK-DAG: [[VAR_12_:%.+]] = arith.addi [[VAR_arg4_]], [[VAR_cst_1_]] : tensor<4xi32>