Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/BuildOnLinuxOSX.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ Firstly, install MLIR (as a part of LLVM-Project):
``` bash
git clone -n https://github.com/llvm/llvm-project.git
# Check out a specific commit that is known to work with ONNX-MLIR.
cd llvm-project && git checkout e86910337f98e57f5b9253f7d80d5b916eb1d97e && cd ..
cd llvm-project && git checkout b270525f730be6e7196667925f5a9bfa153262e9 && cd ..
```

[same-as-file]: <> (utils/build-mlir.sh)
Expand Down
2 changes: 1 addition & 1 deletion docs/BuildOnWindows.md
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ Install MLIR (as a part of LLVM-Project):
```shell
git clone -n https://github.com/llvm/llvm-project.git
# Check out a specific commit that is known to work with ONNX-MLIR.
cd llvm-project && git checkout e86910337f98e57f5b9253f7d80d5b916eb1d97e && cd ..
cd llvm-project && git checkout b270525f730be6e7196667925f5a9bfa153262e9 && cd ..
```

[same-as-file]: <> (utils/build-mlir.cmd)
Expand Down
13 changes: 12 additions & 1 deletion src/Conversion/KrnlToLLVM/ConvertKrnlToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -193,12 +193,23 @@ void populateAffineAndKrnlToLLVMConversion(RewritePatternSet &patterns,
// They run it in two steps, and add additional lowerings.

vector::populateVectorToVectorCanonicalizationPatterns(patterns);
vector::populateVectorBitCastLoweringPatterns(patterns);
vector::populateVectorBroadcastLoweringPatterns(patterns);
vector::populateVectorContractLoweringPatterns(
patterns, vector::VectorTransformsOptions());
vector::populateVectorMaskOpLoweringPatterns(patterns);
vector::populateVectorShapeCastLoweringPatterns(patterns);
vector::populateVectorInterleaveLoweringPatterns(patterns);
vector::populateVectorTransposeLoweringPatterns(
patterns, vector::VectorTransformsOptions());
vector::populateVectorShapeCastLoweringPatterns(patterns);
// Vector transfer ops with rank > 1 should be lowered with VectorToSCF.
vector::populateVectorTransferLoweringPatterns(
patterns, /*maxTransferRank=*/1);
vector::populateVectorMaskMaterializationPatterns(
patterns, /*force32BitVectorIndices*/ false);
vector::populateVectorInsertExtractStridedSliceTransforms(patterns);
vector::populateVectorStepLoweringPatterns(patterns);
vector::populateVectorRankReducingFMAPattern(patterns);

populateAffineToStdConversionPatterns(patterns);
populateSCFToControlFlowConversionPatterns(patterns);
Expand Down
12 changes: 8 additions & 4 deletions src/Conversion/ONNXToTOSA/Math/Conv2D.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ Value createConvInGroups(PatternRewriter &rewriter, Operation *op,
const llvm::ArrayRef<int64_t> weightShape, Value &newInput,
Value &newWeight, Value &bias, const int64_t groups,
DenseI64ArrayAttr &pads, DenseI64ArrayAttr &strides,
DenseI64ArrayAttr &dilations) {
DenseI64ArrayAttr &dilations, TypeAttr &accType) {
// Set up constants outside of loop
const int64_t sizeOfSliceInput = weightShape[1];
const int64_t sizeOfSliceKernel = weightShape[0] / groups;
Expand Down Expand Up @@ -72,7 +72,7 @@ Value createConvInGroups(PatternRewriter &rewriter, Operation *op,
mlir::cast<ShapedType>(resultType).getElementType());
Value tempConv2D = tosa::CreateOpAndInfer<mlir::tosa::Conv2DOp>(rewriter,
op->getLoc(), newConvOutputType, newSliceInput, newSliceWeight,
newSliceBias, pads, strides, dilations);
newSliceBias, pads, strides, dilations, accType);
// Add value to vector
sliceValues.push_back(tempConv2D);
}
Expand Down Expand Up @@ -147,6 +147,10 @@ class ONNXConvOpLoweringToTOSA : public ConversionPattern {
DenseI64ArrayAttr newPads =
rewriter.getDenseI64ArrayAttr({pads[0], pads[2], pads[1], pads[3]});

Type convType =
(resultType.isF16()) ? rewriter.getF16Type() : rewriter.getF32Type();
TypeAttr accType = mlir::TypeAttr::get(convType);

// Handle group parameter by creating multiple convs
const int64_t group = adaptor.getGroup();
Value conv2D = NULL;
Expand All @@ -157,11 +161,11 @@ class ONNXConvOpLoweringToTOSA : public ConversionPattern {

conv2D = tosa::CreateOpAndInfer<mlir::tosa::Conv2DOp>(rewriter,
convOp->getLoc(), newConvOutputType, newInput, newWeight, bias,
newPads, strides, dilations);
newPads, strides, dilations, accType);
} else {
conv2D = createConvInGroups(rewriter, convOp, tosaBuilder, resultType,
weightShape, newInput, newWeight, bias, group, newPads, strides,
dilations);
dilations, accType);
}

// Convert output [N,OH,OW,OC] -> [N,OC,OH,OW]
Expand Down
8 changes: 3 additions & 5 deletions src/Conversion/ONNXToTOSA/ONNXToTOSALegalizeUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ Value buildOnnxToTosaPaddingConstOp(mlir::PatternRewriter &rewriter,

// Create a new pad vec in the right format
// ONNX : [b1, b2, b3, b4, e1, e2, e3, e4]
// TOSA :[[b1, e1], [b2, e2], [b3, e3], [b4, e4]]
// TOSA :[b1, e1, b2, e2, b3, e3, b4, e4]

// Adds any initial or last vals, not included in onnxPads.
llvm::SmallVector<int64_t, 8> tosaPads{initialVals};
Expand All @@ -59,11 +59,9 @@ Value buildOnnxToTosaPaddingConstOp(mlir::PatternRewriter &rewriter,
tosaPads.push_back(onnxPads[i + dimSize]);
}
tosaPads.insert(tosaPads.end(), lastVals.begin(), lastVals.end());

// TOSA format groups dimensions by 2.
const unsigned int numberOfDims = tosaPads.size() / 2;
TosaBuilder tosaBuilder(rewriter, loc);
return tosaBuilder.getConst(tosaPads, {numberOfDims, 2});
return tosaBuilder.getConst(
tosaPads, {static_cast<int64_t>(tosaPads.size())});
}

} // namespace tosa
Expand Down
33 changes: 8 additions & 25 deletions test/mlir/conversion/krnl_to_llvm/krnl_math_function_lowering.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,7 @@ func.func @test_krnl_erf_lowering(%arg0: memref<10x10xf32>) -> memref<10x10xf32>

// CHECK-LABEL: test_krnl_erf_lowering
// CHECK: [[MEMREF_IN:%.+]] = llvm.insertvalue %arg6, {{.*}}[4, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: [[BUILTIN_CAST_0:%.+]] = builtin.unrealized_conversion_cast [[MEMREF_IN]] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> to memref<10x10xf32>
// CHECK: [[BUILTIN_CAST_1:%.+]] = builtin.unrealized_conversion_cast [[BUILTIN_CAST_0]] : memref<10x10xf32> to !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: [[DATA:%.+]] = llvm.extractvalue [[BUILTIN_CAST_1]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: [[DATA:%.+]] = llvm.extractvalue [[MEMREF_IN]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: [[DATA_IN:%.+]] = llvm.getelementptr [[DATA]]{{.*}} : (!llvm.ptr, i64) -> !llvm.ptr, f32
// CHECK: [[SCALAR_IN:%.+]] = llvm.load [[DATA_IN]] : !llvm.ptr
// CHECK: [[ERF_RES:%.+]] = llvm.call @erff([[SCALAR_IN]]) : (f32) -> f32
Expand All @@ -41,9 +39,7 @@ func.func @test_krnl_acos_lowering(%arg0: memref<10x10xf32>) -> memref<10x10xf32

// CHECK-LABEL: test_krnl_acos_lowering
// CHECK: [[MEMREF_IN:%.+]] = llvm.insertvalue %arg6, {{.*}}[4, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: [[BUILTIN_CAST_0:%.+]] = builtin.unrealized_conversion_cast [[MEMREF_IN]] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> to memref<10x10xf32>
// CHECK: [[BUILTIN_CAST_1:%.+]] = builtin.unrealized_conversion_cast [[BUILTIN_CAST_0]] : memref<10x10xf32> to !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: [[DATA:%.+]] = llvm.extractvalue [[BUILTIN_CAST_1]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: [[DATA:%.+]] = llvm.extractvalue [[MEMREF_IN]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: [[DATA_IN:%.+]] = llvm.getelementptr [[DATA]]{{.*}} : (!llvm.ptr, i64) -> !llvm.ptr, f32
// CHECK: [[SCALAR_IN:%.+]] = llvm.load [[DATA_IN]] : !llvm.ptr
// CHECK: [[ACOS_RES:%.+]] = llvm.call @acosf([[SCALAR_IN]]) : (f32) -> f32
Expand All @@ -66,9 +62,7 @@ func.func @test_krnl_acosh_lowering(%arg0: memref<10x10xf32>) -> memref<10x10xf3

// CHECK-LABEL: test_krnl_acosh_lowering
// CHECK: [[MEMREF_IN:%.+]] = llvm.insertvalue %arg6, {{.*}}[4, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: [[BUILTIN_CAST_0:%.+]] = builtin.unrealized_conversion_cast [[MEMREF_IN]] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> to memref<10x10xf32>
// CHECK: [[BUILTIN_CAST_1:%.+]] = builtin.unrealized_conversion_cast [[BUILTIN_CAST_0]] : memref<10x10xf32> to !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: [[DATA:%.+]] = llvm.extractvalue [[BUILTIN_CAST_1]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: [[DATA:%.+]] = llvm.extractvalue [[MEMREF_IN]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: [[DATA_IN:%.+]] = llvm.getelementptr [[DATA]]{{.*}} : (!llvm.ptr, i64) -> !llvm.ptr, f32
// CHECK: [[SCALAR_IN:%.+]] = llvm.load [[DATA_IN]] : !llvm.ptr
// CHECK: [[ACOS_RES:%.+]] = llvm.call @acoshf([[SCALAR_IN]]) : (f32) -> f32
Expand All @@ -91,9 +85,7 @@ func.func @test_krnl_asin_lowering(%arg0: memref<10x10xf32>) -> memref<10x10xf32

// CHECK-LABEL: test_krnl_asin_lowering
// CHECK: [[MEMREF_IN:%.+]] = llvm.insertvalue %arg6, {{.*}}[4, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: [[BUILTIN_CAST_0:%.+]] = builtin.unrealized_conversion_cast [[MEMREF_IN]] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> to memref<10x10xf32>
// CHECK: [[BUILTIN_CAST_1:%.+]] = builtin.unrealized_conversion_cast [[BUILTIN_CAST_0]] : memref<10x10xf32> to !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: [[DATA:%.+]] = llvm.extractvalue [[BUILTIN_CAST_1]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: [[DATA:%.+]] = llvm.extractvalue [[MEMREF_IN]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: [[DATA_IN:%.+]] = llvm.getelementptr [[DATA]]{{.*}} : (!llvm.ptr, i64) -> !llvm.ptr, f32
// CHECK: [[SCALAR_IN:%.+]] = llvm.load [[DATA_IN]] : !llvm.ptr
// CHECK: [[ACOS_RES:%.+]] = llvm.call @asinf([[SCALAR_IN]]) : (f32) -> f32
Expand All @@ -116,9 +108,7 @@ func.func @test_krnl_asinh_lowering(%arg0: memref<10x10xf32>) -> memref<10x10xf3

// CHECK-LABEL: test_krnl_asinh_lowering
// CHECK: [[MEMREF_IN:%.+]] = llvm.insertvalue %arg6, {{.*}}[4, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: [[BUILTIN_CAST_0:%.+]] = builtin.unrealized_conversion_cast [[MEMREF_IN]] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> to memref<10x10xf32>
// CHECK: [[BUILTIN_CAST_1:%.+]] = builtin.unrealized_conversion_cast [[BUILTIN_CAST_0]] : memref<10x10xf32> to !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: [[DATA:%.+]] = llvm.extractvalue [[BUILTIN_CAST_1]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: [[DATA:%.+]] = llvm.extractvalue [[MEMREF_IN]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: [[DATA_IN:%.+]] = llvm.getelementptr [[DATA]]{{.*}} : (!llvm.ptr, i64) -> !llvm.ptr, f32
// CHECK: [[SCALAR_IN:%.+]] = llvm.load [[DATA_IN]] : !llvm.ptr
// CHECK: [[ACOS_RES:%.+]] = llvm.call @asinhf([[SCALAR_IN]]) : (f32) -> f32
Expand All @@ -141,9 +131,7 @@ func.func @test_krnl_atan_lowering(%arg0: memref<10x10xf32>) -> memref<10x10xf32

// CHECK-LABEL: test_krnl_atan_lowering
// CHECK: [[MEMREF_IN:%.+]] = llvm.insertvalue %arg6, {{.*}}[4, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: [[BUILTIN_CAST_0:%.+]] = builtin.unrealized_conversion_cast [[MEMREF_IN]] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> to memref<10x10xf32>
// CHECK: [[BUILTIN_CAST_1:%.+]] = builtin.unrealized_conversion_cast [[BUILTIN_CAST_0]] : memref<10x10xf32> to !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: [[DATA:%.+]] = llvm.extractvalue [[BUILTIN_CAST_1]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: [[DATA:%.+]] = llvm.extractvalue [[MEMREF_IN]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: [[DATA_IN:%.+]] = llvm.getelementptr [[DATA]]{{.*}} : (!llvm.ptr, i64) -> !llvm.ptr, f32
// CHECK: [[SCALAR_IN:%.+]] = llvm.load [[DATA_IN]] : !llvm.ptr
// CHECK: [[ACOS_RES:%.+]] = llvm.call @atanf([[SCALAR_IN]]) : (f32) -> f32
Expand All @@ -165,9 +153,7 @@ func.func @test_krnl_atanh_lowering(%arg0: memref<10x10xf32>) -> memref<10x10xf3

// CHECK-LABEL: test_krnl_atanh_lowering
// CHECK: [[MEMREF_IN:%.+]] = llvm.insertvalue %arg6, {{.*}}[4, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: [[BUILTIN_CAST_0:%.+]] = builtin.unrealized_conversion_cast [[MEMREF_IN]] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> to memref<10x10xf32>
// CHECK: [[BUILTIN_CAST_1:%.+]] = builtin.unrealized_conversion_cast [[BUILTIN_CAST_0]] : memref<10x10xf32> to !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: [[DATA:%.+]] = llvm.extractvalue [[BUILTIN_CAST_1]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: [[DATA:%.+]] = llvm.extractvalue [[MEMREF_IN]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: [[DATA_IN:%.+]] = llvm.getelementptr [[DATA]]{{.*}} : (!llvm.ptr, i64) -> !llvm.ptr, f32
// CHECK: [[SCALAR_IN:%.+]] = llvm.load [[DATA_IN]] : !llvm.ptr
// CHECK: [[ACOS_RES:%.+]] = llvm.call @atanhf([[SCALAR_IN]]) : (f32) -> f32
Expand All @@ -189,12 +175,9 @@ func.func @test_krnl_tan_lowering(%arg0: memref<10x10xf32>) -> memref<10x10xf32>

// CHECK-LABEL: test_krnl_tan_lowering
// CHECK: [[MEMREF_IN:%.+]] = llvm.insertvalue %arg6, {{.*}}[4, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: [[BUILTIN_CAST_0:%.+]] = builtin.unrealized_conversion_cast [[MEMREF_IN]] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> to memref<10x10xf32>
// CHECK: [[BUILTIN_CAST_1:%.+]] = builtin.unrealized_conversion_cast [[BUILTIN_CAST_0]] : memref<10x10xf32> to !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: [[DATA:%.+]] = llvm.extractvalue [[BUILTIN_CAST_1]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: [[DATA:%.+]] = llvm.extractvalue [[MEMREF_IN]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: [[DATA_IN:%.+]] = llvm.getelementptr [[DATA]]{{.*}} : (!llvm.ptr, i64) -> !llvm.ptr, f32
// CHECK: [[SCALAR_IN:%.+]] = llvm.load [[DATA_IN]] : !llvm.ptr
// CHECK: [[ACOS_RES:%.+]] = llvm.call @tanf([[SCALAR_IN]]) : (f32) -> f32
// CHECK: [[DATA_OUT:%.+]] = llvm.getelementptr {{.*}} : (!llvm.ptr, i64) -> !llvm.ptr, f32
// CHECK: llvm.store [[ACOS_RES]], [[DATA_OUT]] : f32, !llvm.ptr

Loading
Loading