diff --git a/.gitmodules b/.gitmodules index d3b5f516f86b..8b46098d9615 100644 --- a/.gitmodules +++ b/.gitmodules @@ -1,6 +1,6 @@ [submodule "externals/llvm-project"] path = externals/llvm-project - url = https://github.com/iree-org/llvm-project.git + url = https://github.com/llvm/llvm-project.git [submodule "externals/stablehlo"] path = externals/stablehlo url = https://github.com/openxla/stablehlo.git diff --git a/externals/llvm-project b/externals/llvm-project index 8007c56fa699..41f65666f637 160000 --- a/externals/llvm-project +++ b/externals/llvm-project @@ -1 +1 @@ -Subproject commit 8007c56fa699040c6c921906a5fbfcd5c9bb0953 +Subproject commit 41f65666f6378bba7266be7c662c70074f04ed75 diff --git a/include/torch-mlir/Conversion/TorchOnnxToTorch/Utils.h b/include/torch-mlir/Conversion/TorchOnnxToTorch/Utils.h index 2f5aab5b9364..31070901027d 100644 --- a/include/torch-mlir/Conversion/TorchOnnxToTorch/Utils.h +++ b/include/torch-mlir/Conversion/TorchOnnxToTorch/Utils.h @@ -44,8 +44,8 @@ Torch::ValueTensorType getQTorchTypeFromTorchIntType(Type ty); template Value getItemOp(OpBinder binder, ConversionPatternRewriter &rewriter, Value &ofItem) { - return rewriter.create(binder.getLoc(), - rewriter.getType(), ofItem); + return Torch::AtenItemOp::create(rewriter, binder.getLoc(), + rewriter.getType(), ofItem); } LogicalResult OnnxLstmExpander(OpBinder binder, diff --git a/lib/Conversion/TorchConversionToMLProgram/TorchConversionToMLProgram.cpp b/lib/Conversion/TorchConversionToMLProgram/TorchConversionToMLProgram.cpp index ddcfab78ac8f..e5dee5b4c3bb 100644 --- a/lib/Conversion/TorchConversionToMLProgram/TorchConversionToMLProgram.cpp +++ b/lib/Conversion/TorchConversionToMLProgram/TorchConversionToMLProgram.cpp @@ -40,8 +40,8 @@ static LogicalResult getOrCreateGlobalVariableForSeed(OpBuilder &b, } b.setInsertionPointToStart(module.getBody()); - b.create( - UnknownLoc::get(b.getContext()), + ml_program::GlobalOp::create( + b, UnknownLoc::get(b.getContext()), 
/*sym_name=*/getSeedGobalVarName(), /*type=*/tensorType, /*is_mutable=*/true, @@ -71,25 +71,25 @@ class ConvertGetNextSeedOp : public OpConversionPattern { // Refer to https://en.wikipedia.org/wiki/Linear_congruential_generator. // Get the current seed value. auto tensorType = RankedTensorType::get({}, rewriter.getI64Type()); - Value globalVar = rewriter.create( - loc, tensorType, + Value globalVar = ml_program::GlobalLoadOp::create( + rewriter, loc, tensorType, SymbolRefAttr::get(op->getContext(), getSeedGobalVarName())); - Value currentSeed = rewriter.create(loc, globalVar); + Value currentSeed = tensor::ExtractOp::create(rewriter, loc, globalVar); // The value of multiplier and incrementStep are referenced from // https://en.wikipedia.org/wiki/Linear_congruential_generator for 2^64. - Value multiplier = rewriter.create( - loc, rewriter.getI64IntegerAttr(6364136223846793005)); - Value incrementStep = rewriter.create( - loc, rewriter.getI64IntegerAttr(1442695040888963407)); + Value multiplier = arith::ConstantOp::create( + rewriter, loc, rewriter.getI64IntegerAttr(6364136223846793005)); + Value incrementStep = arith::ConstantOp::create( + rewriter, loc, rewriter.getI64IntegerAttr(1442695040888963407)); // temp = multiplier * currentSeed + incrementStep - Value mul = rewriter.create(loc, currentSeed, multiplier); - Value seed = rewriter.create(loc, mul, incrementStep); + Value mul = arith::MulIOp::create(rewriter, loc, currentSeed, multiplier); + Value seed = arith::AddIOp::create(rewriter, loc, mul, incrementStep); globalVar = - rewriter.create(loc, seed, globalVar, ValueRange()); - rewriter.create( - loc, SymbolRefAttr::get(op->getContext(), getSeedGobalVarName()), - globalVar); + tensor::InsertOp::create(rewriter, loc, seed, globalVar, ValueRange()); + ml_program::GlobalStoreOp::create( + rewriter, loc, + SymbolRefAttr::get(op->getContext(), getSeedGobalVarName()), globalVar); rewriter.replaceOp(op, seed); return success(); } diff --git 
a/lib/Conversion/TorchOnnxToTorch/ComMicrosoftDomain.cpp b/lib/Conversion/TorchOnnxToTorch/ComMicrosoftDomain.cpp index 760170197154..49b028db3de1 100644 --- a/lib/Conversion/TorchOnnxToTorch/ComMicrosoftDomain.cpp +++ b/lib/Conversion/TorchOnnxToTorch/ComMicrosoftDomain.cpp @@ -46,16 +46,16 @@ void mlir::torch::onnx_c::populateComMicrosoftDomain( "result type bind failure"); } - Value cstInterleaved = rewriter.create( - loc, rewriter.getI64IntegerAttr(interleaved)); - Value cstIsPackedBatching = rewriter.create( - loc, rewriter.getI64IntegerAttr(isPackedBatching)); - Value cstNumHeads = rewriter.create( - loc, rewriter.getI64IntegerAttr(numHeads)); - Value cstRotaryEmbeddingDim = rewriter.create( - loc, rewriter.getI64IntegerAttr(rotaryEmbeddingDim)); - Value cstScale = rewriter.create( - loc, rewriter.getF64FloatAttr(scale)); + Value cstInterleaved = Torch::ConstantIntOp::create( + rewriter, loc, rewriter.getI64IntegerAttr(interleaved)); + Value cstIsPackedBatching = Torch::ConstantIntOp::create( + rewriter, loc, rewriter.getI64IntegerAttr(isPackedBatching)); + Value cstNumHeads = Torch::ConstantIntOp::create( + rewriter, loc, rewriter.getI64IntegerAttr(numHeads)); + Value cstRotaryEmbeddingDim = Torch::ConstantIntOp::create( + rewriter, loc, rewriter.getI64IntegerAttr(rotaryEmbeddingDim)); + Value cstScale = Torch::ConstantFloatOp::create( + rewriter, loc, rewriter.getF64FloatAttr(scale)); rewriter.replaceOpWithNewOp( binder.op, resultType, input, positionIds, cosCache, sinCache, @@ -160,18 +160,19 @@ void mlir::torch::onnx_c::populateComMicrosoftDomain( int64_t hiddenSize = queryDims[2]; int64_t headSize = hiddenSize / numHeads; - Value cstBatchSize = rewriter.create( - binder.getLoc(), rewriter.getI64IntegerAttr(batchSize)); - Value cstSequenceLength = rewriter.create( - binder.getLoc(), rewriter.getI64IntegerAttr(sequenceLength)); - Value cstHiddenSize = rewriter.create( - binder.getLoc(), rewriter.getI64IntegerAttr(hiddenSize)); - Value cstHeadSize = 
rewriter.create( - binder.getLoc(), rewriter.getI64IntegerAttr(headSize)); - Value cstNumHeads = rewriter.create( - binder.getLoc(), rewriter.getI64IntegerAttr(numHeads)); - Value cstKVNumHeads = rewriter.create( - binder.getLoc(), rewriter.getI64IntegerAttr(kvNumHeads)); + Value cstBatchSize = Torch::ConstantIntOp::create( + rewriter, binder.getLoc(), rewriter.getI64IntegerAttr(batchSize)); + Value cstSequenceLength = Torch::ConstantIntOp::create( + rewriter, binder.getLoc(), + rewriter.getI64IntegerAttr(sequenceLength)); + Value cstHiddenSize = Torch::ConstantIntOp::create( + rewriter, binder.getLoc(), rewriter.getI64IntegerAttr(hiddenSize)); + Value cstHeadSize = Torch::ConstantIntOp::create( + rewriter, binder.getLoc(), rewriter.getI64IntegerAttr(headSize)); + Value cstNumHeads = Torch::ConstantIntOp::create( + rewriter, binder.getLoc(), rewriter.getI64IntegerAttr(numHeads)); + Value cstKVNumHeads = Torch::ConstantIntOp::create( + rewriter, binder.getLoc(), rewriter.getI64IntegerAttr(kvNumHeads)); // Reshape Query, Key and Value as follows: // Query: (batch_size, sequence_length, hidden_size) @@ -184,14 +185,13 @@ void mlir::torch::onnx_c::populateComMicrosoftDomain( // Reshaping query. 
SmallVector queryReshapeSizesInt{batchSize, numHeads, sequenceLength, headSize}; - Value queryReshapeSizesList = - rewriter.create( - binder.getLoc(), - Torch::ListType::get(Torch::IntType::get(query.getContext())), - llvm::SmallVector{cstBatchSize, cstNumHeads, - cstSequenceLength, cstHeadSize}); - Value qInput = rewriter.create( - loc, + Value queryReshapeSizesList = Torch::PrimListConstructOp::create( + rewriter, binder.getLoc(), + Torch::ListType::get(Torch::IntType::get(query.getContext())), + llvm::SmallVector{cstBatchSize, cstNumHeads, + cstSequenceLength, cstHeadSize}); + Value qInput = Torch::AtenReshapeOp::create( + rewriter, loc, queryType.getWithSizesAndDtype(queryReshapeSizesInt, queryType.getOptionalDtype()), query, queryReshapeSizesList); @@ -199,15 +199,15 @@ void mlir::torch::onnx_c::populateComMicrosoftDomain( // Reshaping key. SmallVector kvReshapeSizesInt{batchSize, kvNumHeads, sequenceLength, headSize}; - Value kvReshapeSizesList = rewriter.create( - binder.getLoc(), + Value kvReshapeSizesList = Torch::PrimListConstructOp::create( + rewriter, binder.getLoc(), Torch::ListType::get(Torch::IntType::get(query.getContext())), llvm::SmallVector{cstBatchSize, cstKVNumHeads, cstSequenceLength, cstHeadSize}); Torch::ValueTensorType keyType = cast(key.getType()); - Value kInput = rewriter.create( - loc, + Value kInput = Torch::AtenReshapeOp::create( + rewriter, loc, keyType.getWithSizesAndDtype(kvReshapeSizesInt, keyType.getOptionalDtype()), key, kvReshapeSizesList); @@ -215,32 +215,33 @@ void mlir::torch::onnx_c::populateComMicrosoftDomain( // Reshaping value. 
Torch::ValueTensorType valueType = cast(value.getType()); - Value vInput = rewriter.create( - loc, + Value vInput = Torch::AtenReshapeOp::create( + rewriter, loc, valueType.getWithSizesAndDtype(kvReshapeSizesInt, valueType.getOptionalDtype()), value, kvReshapeSizesList); - Value cstNone = rewriter.create(loc); - Value cstFalse = rewriter.create(loc, false); + Value cstNone = Torch::ConstantNoneOp::create(rewriter, loc); + Value cstFalse = Torch::ConstantBoolOp::create(rewriter, loc, false); Value qRotary = qInput, kRotary = kInput; if (doRotary) { // `totalSequenceLength` is a scalar tensor. - Value scalarTotalSeqLens = rewriter.create( - loc, rewriter.getType(), totalSequenceLength); - Value cstIntOne = rewriter.create( - binder.getLoc(), rewriter.getI64IntegerAttr(1)); + Value scalarTotalSeqLens = Torch::AtenItemOp::create( + rewriter, loc, rewriter.getType(), + totalSequenceLength); + Value cstIntOne = Torch::ConstantIntOp::create( + rewriter, binder.getLoc(), rewriter.getI64IntegerAttr(1)); Type boolTy = rewriter.getType(); - Value condA = rewriter.create( - loc, boolTy, cstSequenceLength, cstIntOne); - Value condB = rewriter.create( - loc, boolTy, cstSequenceLength, scalarTotalSeqLens); + Value condA = Torch::AtenGtIntOp::create( + rewriter, loc, boolTy, cstSequenceLength, cstIntOne); + Value condB = Torch::AtenNeIntOp::create( + rewriter, loc, boolTy, cstSequenceLength, scalarTotalSeqLens); // if (sequence_length > 1 && sequence_length != // total_sequence_length) // is_subsequent_prompt = false; // Subsequent prompt - Value isSubsequentPrompt = rewriter.create( - loc, boolTy, condA, condB); + Value isSubsequentPrompt = Torch::Aten__And__BoolOp::create( + rewriter, loc, boolTy, condA, condB); // Generating position_ids for rotary_embedding as follows: // pos_ids_a = torch.zeros((batch_size, seq_len), dtype=torch.int64) @@ -262,27 +263,29 @@ void mlir::torch::onnx_c::populateComMicrosoftDomain( Torch::ValueTensorType positionIdsType = 
Torch::ValueTensorType::get( context, positionIdsSizeInt, IntegerType::get(context, 64, IntegerType::Signed)); - Value cstInt64Dtype = rewriter.create( - binder.getLoc(), rewriter.getI64IntegerAttr( - (int)torch_upstream::ScalarType::Long)); - - Value cstInterleaved = rewriter.create( - binder.getLoc(), rewriter.getI64IntegerAttr(rotaryInterleaved)); - Value cstIntZero = rewriter.create( - binder.getLoc(), rewriter.getI64IntegerAttr(0)); - Value cstFloatOne = rewriter.create( - binder.getLoc(), rewriter.getType(), + Value cstInt64Dtype = Torch::ConstantIntOp::create( + rewriter, binder.getLoc(), + rewriter.getI64IntegerAttr( + (int)torch_upstream::ScalarType::Long)); + + Value cstInterleaved = Torch::ConstantIntOp::create( + rewriter, binder.getLoc(), + rewriter.getI64IntegerAttr(rotaryInterleaved)); + Value cstIntZero = Torch::ConstantIntOp::create( + rewriter, binder.getLoc(), rewriter.getI64IntegerAttr(0)); + Value cstFloatOne = Torch::ConstantFloatOp::create( + rewriter, binder.getLoc(), rewriter.getType(), rewriter.getF64FloatAttr(1.0)); Value positionIdsA, positionIdsB; - Value posIdsSizeList = rewriter.create( - loc, + Value posIdsSizeList = Torch::PrimListConstructOp::create( + rewriter, loc, rewriter.getType( rewriter.getType()), SmallVector{cstBatchSize, cstSequenceLength}); - positionIdsA = rewriter.create( - loc, positionIdsType, /*size=*/posIdsSizeList, + positionIdsA = Torch::AtenZerosOp::create( + rewriter, loc, positionIdsType, /*size=*/posIdsSizeList, /*dtype=*/cstInt64Dtype, /*layout=*/cstNone, /*device=*/cstNone, /*pin_memory=*/cstNone); @@ -290,133 +293,133 @@ void mlir::torch::onnx_c::populateComMicrosoftDomain( // Convert seqlens_k which is a tensor of type si32 to si64. 
Torch::ValueTensorType seqLensKType = cast(seqlensK.getType()); - seqlensK = rewriter.create( - loc, + seqlensK = Torch::AtenToDtypeOp::create( + rewriter, loc, seqLensKType.getWithSizesAndDtype( std::nullopt, rewriter.getIntegerType(/*width=*/64, /*isSigned=*/true)), seqlensK, cstInt64Dtype, /*non_blocking=*/cstFalse, /*copy=*/cstFalse, /*memory_format=*/cstNone); - Value totalSeqLens = rewriter.create( - loc, seqlensK.getType(), /*self=*/seqlensK, /*other=*/cstIntOne, + Value totalSeqLens = Torch::AtenAddScalarOp::create( + rewriter, loc, seqlensK.getType(), /*self=*/seqlensK, + /*other=*/cstIntOne, /*alpha=*/cstIntOne); - Value pastSeqLens = rewriter.create( - loc, totalSeqLens.getType(), /*self=*/totalSeqLens, + Value pastSeqLens = Torch::AtenSubScalarOp::create( + rewriter, loc, totalSeqLens.getType(), /*self=*/totalSeqLens, /*other=*/cstSequenceLength, /*alpha=*/cstIntOne); Torch::ValueTensorType initPosIdsType = Torch::ValueTensorType::get( context, {sequenceLength}, IntegerType::get(context, 64, IntegerType::Signed)); - Value initPosIds = rewriter.create( - loc, initPosIdsType, cstSequenceLength, cstInt64Dtype, + Value initPosIds = Torch::AtenArangeOp::create( + rewriter, loc, initPosIdsType, cstSequenceLength, cstInt64Dtype, /*layout=*/cstNone, /*device=*/cstNone, /*pin_memory=*/cstNone); - Value repeatValuesList = rewriter.create( - binder.getLoc(), + Value repeatValuesList = Torch::PrimListConstructOp::create( + rewriter, binder.getLoc(), Torch::ListType::get(Torch::IntType::get(context)), llvm::SmallVector{cstBatchSize, cstIntOne}); - positionIdsB = rewriter.create( - loc, positionIdsType, initPosIds, /*repeats=*/repeatValuesList); - - Value cstIntMinusOne = rewriter.create( - binder.getLoc(), rewriter.getI64IntegerAttr(1)); - Value viewSizeList = rewriter.create( - binder.getLoc(), + positionIdsB = Torch::AtenRepeatOp::create( + rewriter, loc, positionIdsType, initPosIds, + /*repeats=*/repeatValuesList); + + Value cstIntMinusOne = 
Torch::ConstantIntOp::create( + rewriter, binder.getLoc(), rewriter.getI64IntegerAttr(1)); + Value viewSizeList = Torch::PrimListConstructOp::create( + rewriter, binder.getLoc(), Torch::ListType::get(Torch::IntType::get(context)), llvm::SmallVector{cstIntMinusOne, cstIntOne}); Torch::ValueTensorType seqLensViewType = Torch::ValueTensorType::get( context, llvm::SmallVector{batchSize, 1}, IntegerType::get(context, 64, IntegerType::Signed)); - pastSeqLens = rewriter.create( - loc, seqLensViewType, pastSeqLens, viewSizeList); + pastSeqLens = Torch::AtenViewOp::create( + rewriter, loc, seqLensViewType, pastSeqLens, viewSizeList); - positionIdsB = rewriter.create( - loc, positionIdsType, positionIdsB, pastSeqLens, + positionIdsB = Torch::AtenAddTensorOp::create( + rewriter, loc, positionIdsType, positionIdsB, pastSeqLens, /*alpha=*/cstIntOne); - totalSeqLens = rewriter.create( - loc, seqLensViewType, totalSeqLens, viewSizeList); - Value cond = rewriter.create( - loc, + totalSeqLens = Torch::AtenViewOp::create( + rewriter, loc, seqLensViewType, totalSeqLens, viewSizeList); + Value cond = Torch::AtenLtTensorOp::create( + rewriter, loc, positionIdsType.getWithSizesAndDtype(positionIdsType.getSizes(), rewriter.getI1Type()), positionIdsB, totalSeqLens); - Value cstOneTensorDataList = - rewriter.create( - loc, - rewriter.getType( - rewriter.getType()), - SmallVector{cstIntOne}); - Value cstOneTensor = rewriter.create( - loc, + Value cstOneTensorDataList = Torch::PrimListConstructOp::create( + rewriter, loc, + rewriter.getType( + rewriter.getType()), + SmallVector{cstIntOne}); + Value cstOneTensor = Torch::AtenTensorOp::create( + rewriter, loc, Torch::ValueTensorType::get( context, {}, IntegerType::get(context, 64, IntegerType::Signed)), cstOneTensorDataList, /*dtype=*/cstInt64Dtype, /*layout=*/cstNone, /*requires_grad=*/cstFalse); - positionIdsB = rewriter.create( - loc, positionIdsType, cond, positionIdsB, cstOneTensor); + positionIdsB = Torch::AtenWhereSelfOp::create( + 
rewriter, loc, positionIdsType, cond, positionIdsB, cstOneTensor); - isSubsequentPrompt = rewriter.create( - loc, rewriter.getType(), isSubsequentPrompt); - isSubsequentPrompt = rewriter.create( - loc, + isSubsequentPrompt = Torch::AtenIntBoolOp::create( + rewriter, loc, rewriter.getType(), + isSubsequentPrompt); + isSubsequentPrompt = Torch::AtenFullOp::create( + rewriter, loc, Torch::ValueTensorType::get(context, positionIdsSizeInt, rewriter.getI1Type()), /*size=*/posIdsSizeList, /*fill_value=*/isSubsequentPrompt, /*dtype=*/ - rewriter.create( - binder.getLoc(), rewriter.getI64IntegerAttr( - (int)torch_upstream::ScalarType::Bool)), + Torch::ConstantIntOp::create( + rewriter, binder.getLoc(), + rewriter.getI64IntegerAttr( + (int)torch_upstream::ScalarType::Bool)), /*layout=*/cstNone, /*device=*/cstNone, /*pin_memory=*/cstNone); - Value positionIds = rewriter.create( - loc, positionIdsType, isSubsequentPrompt, positionIdsB, + Value positionIds = Torch::AtenWhereSelfOp::create( + rewriter, loc, positionIdsType, isSubsequentPrompt, positionIdsB, positionIdsA); // Performing RotaryEmbedding over Query and Key. - qRotary = rewriter.create( - loc, qInput.getType(), qInput, positionIds, cosCache, sinCache, - cstInterleaved, /*is_packed_batching=*/cstIntZero, + qRotary = Torch::OnnxVariantRotaryEmbeddingOp::create( + rewriter, loc, qInput.getType(), qInput, positionIds, cosCache, + sinCache, cstInterleaved, /*is_packed_batching=*/cstIntZero, /*num_heads=*/cstIntZero, /*rotary_embedding_dim=*/cstIntZero, /*scale=*/cstFloatOne); - kRotary = rewriter.create( - loc, qInput.getType(), kInput, positionIds, cosCache, sinCache, - cstInterleaved, /*is_packed_batching=*/cstIntZero, + kRotary = Torch::OnnxVariantRotaryEmbeddingOp::create( + rewriter, loc, qInput.getType(), kInput, positionIds, cosCache, + sinCache, cstInterleaved, /*is_packed_batching=*/cstIntZero, /*num_heads=*/cstIntZero, /*rotary_embedding_dim=*/cstIntZero, /*scale=*/cstFloatOne); } // Do attention. 
- Value cstEnableGQA = rewriter.create(loc, true); - Value cstFloatZero = rewriter.create( - binder.getLoc(), rewriter.getType(), + Value cstEnableGQA = Torch::ConstantBoolOp::create(rewriter, loc, true); + Value cstFloatZero = Torch::ConstantFloatOp::create( + rewriter, binder.getLoc(), rewriter.getType(), rewriter.getF64FloatAttr(0.0)); Value cstScale = cstNone; if (scale != 0.0f) - cstScale = rewriter.create( - binder.getLoc(), rewriter.getType(), + cstScale = Torch::ConstantFloatOp::create( + rewriter, binder.getLoc(), rewriter.getType(), rewriter.getF64FloatAttr(scale)); - Value attention = - rewriter.create( - loc, qRotary.getType(), qRotary, kRotary, vInput, - /*attn_mask=*/cstNone, - /*dropout_p=*/cstFloatZero, /*is_causal=*/cstFalse, cstScale, - cstEnableGQA); + Value attention = Torch::AtenScaledDotProductAttentionOp::create( + rewriter, loc, qRotary.getType(), qRotary, kRotary, vInput, + /*attn_mask=*/cstNone, + /*dropout_p=*/cstFloatZero, /*is_causal=*/cstFalse, cstScale, + cstEnableGQA); // Reshaping the attention result from: // (batch_size, num_heads, sequence_length, head_size) // -> (batch_size, sequence_length, hidden_size) - Value attentionResultSizesList = - rewriter.create( - binder.getLoc(), - Torch::ListType::get( - Torch::IntType::get(attention.getContext())), - llvm::SmallVector{cstBatchSize, cstSequenceLength, - cstHiddenSize}); - attention = rewriter.create( - loc, resultTypes[0], attention, attentionResultSizesList); + Value attentionResultSizesList = Torch::PrimListConstructOp::create( + rewriter, binder.getLoc(), + Torch::ListType::get(Torch::IntType::get(attention.getContext())), + llvm::SmallVector{cstBatchSize, cstSequenceLength, + cstHiddenSize}); + attention = Torch::AtenReshapeOp::create( + rewriter, loc, resultTypes[0], attention, attentionResultSizesList); // Compute 2nd and 3rd result: present_key, present_value. 
// present_key = torch.cat([past_key, key], dim=2) or past_key @@ -425,31 +428,31 @@ void mlir::torch::onnx_c::populateComMicrosoftDomain( if (!llvm::equal( cast(pastKey.getType()).getSizes(), cast(resultTypes[1]).getSizes())) { - Value cstConcatDim = rewriter.create( - binder.getLoc(), rewriter.getI64IntegerAttr(2)); + Value cstConcatDim = Torch::ConstantIntOp::create( + rewriter, binder.getLoc(), rewriter.getI64IntegerAttr(2)); Type kvListElemType = keyType.getWithSizesAndDtype( /*optionalSizes=*/std::nullopt, /*optionalDtype=*/nullptr); Type kvListType = Torch::ListType::get(kvListElemType); - Value keyList = rewriter.create( - loc, kvListType, SmallVector{pastKey, kRotary}); - presentKey = rewriter.create(loc, resultTypes[1], - keyList, cstConcatDim); + Value keyList = Torch::PrimListConstructOp::create( + rewriter, loc, kvListType, SmallVector{pastKey, kRotary}); + presentKey = Torch::AtenCatOp::create(rewriter, loc, resultTypes[1], + keyList, cstConcatDim); } if (!llvm::equal( cast(pastValue.getType()).getSizes(), cast(resultTypes[2]).getSizes())) { - Value cstConcatDim = rewriter.create( - binder.getLoc(), rewriter.getI64IntegerAttr(2)); + Value cstConcatDim = Torch::ConstantIntOp::create( + rewriter, binder.getLoc(), rewriter.getI64IntegerAttr(2)); Type kvListElemType = keyType.getWithSizesAndDtype( /*optionalSizes=*/std::nullopt, /*optionalDtype=*/nullptr); Type kvListType = Torch::ListType::get(kvListElemType); - Value valueList = rewriter.create( - loc, kvListType, SmallVector{pastValue, vInput}); - presentValue = rewriter.create( - loc, resultTypes[2], valueList, cstConcatDim); + Value valueList = Torch::PrimListConstructOp::create( + rewriter, loc, kvListType, SmallVector{pastValue, vInput}); + presentValue = Torch::AtenCatOp::create(rewriter, loc, resultTypes[2], + valueList, cstConcatDim); } rewriter.replaceOp(binder.op, {attention, presentKey, presentValue}); @@ -506,22 +509,22 @@ void mlir::torch::onnx_c::populateComMicrosoftDomain( // Computing 
the result of "Add". auto cTy = rewriter.getType( resultType.getOptionalSizes(), rewriter.getF32Type()); - Value alpha = rewriter.create( - loc, rewriter.getF64FloatAttr(1.0)); - Value c = rewriter.create(binder.getLoc(), cTy, - a, b, alpha); + Value alpha = Torch::ConstantFloatOp::create( + rewriter, loc, rewriter.getF64FloatAttr(1.0)); + Value c = Torch::AtenAddTensorOp::create(rewriter, binder.getLoc(), cTy, + a, b, alpha); // Quantizing the result of "Add" operation. cTy = dyn_cast( getQTorchTypeFromTorchIntType(resultType)); - Value dtyVal = rewriter.create( - binder.getLoc(), rewriter.getType(), + Value dtyVal = Torch::ConstantIntOp::create( + rewriter, binder.getLoc(), rewriter.getType(), rewriter.getIntegerAttr( rewriter.getIntegerType(64), static_cast( Torch::getScalarTypeForType(cTy.getDtype())))); - c = rewriter.create( - binder.getLoc(), cTy, c, cScale, cZp, dtyVal); + c = Torch::AtenQuantizePerTensorOp::create(rewriter, binder.getLoc(), + cTy, c, cScale, cZp, dtyVal); rewriter.replaceOpWithNewOp(binder.op, resultType, c); return success(); @@ -564,25 +567,25 @@ void mlir::torch::onnx_c::populateComMicrosoftDomain( "missing sizes"); // Computing the LeakyRelu result. - Value constAlpha = rewriter.create( - loc, rewriter.getType(), + Value constAlpha = Torch::ConstantFloatOp::create( + rewriter, loc, rewriter.getType(), rewriter.getF64FloatAttr((double)alpha)); auto yTy = rewriter.getType( resultType.getOptionalSizes(), rewriter.getF32Type()); Value y = - rewriter.create(loc, yTy, x, constAlpha); + Torch::AtenLeakyReluOp::create(rewriter, loc, yTy, x, constAlpha); // Quantizing the result of LeakyRelu op. 
yTy = dyn_cast( getQTorchTypeFromTorchIntType(resultType)); - Value dtyVal = rewriter.create( - binder.getLoc(), rewriter.getType(), + Value dtyVal = Torch::ConstantIntOp::create( + rewriter, binder.getLoc(), rewriter.getType(), rewriter.getIntegerAttr( rewriter.getIntegerType(64), static_cast( Torch::getScalarTypeForType(yTy.getDtype())))); - y = rewriter.create(loc, yTy, y, yScale, - yZp, dtyVal); + y = Torch::AtenQuantizePerTensorOp::create(rewriter, loc, yTy, y, + yScale, yZp, dtyVal); rewriter.replaceOpWithNewOp(binder.op, resultType, y); return success(); @@ -642,14 +645,14 @@ void mlir::torch::onnx_c::populateComMicrosoftDomain( .getWithSizesAndDtype(/*optionalSizes=*/std::nullopt, /*optionalDtype=*/nullptr); Type listType = Torch::ListType::get(listElemType); - Value tensorList = rewriter.create( - binder.op->getLoc(), listType, dequantizedInputs); - Value cstAxis = rewriter.create( - loc, rewriter.getI64IntegerAttr(axis)); + Value tensorList = Torch::PrimListConstructOp::create( + rewriter, binder.op->getLoc(), listType, dequantizedInputs); + Value cstAxis = Torch::ConstantIntOp::create( + rewriter, loc, rewriter.getI64IntegerAttr(axis)); auto concatTy = rewriter.getType( resultType.getOptionalSizes(), rewriter.getF32Type()); - Value concat = rewriter.create(loc, concatTy, - tensorList, cstAxis); + Value concat = Torch::AtenCatOp::create(rewriter, loc, concatTy, + tensorList, cstAxis); // Quantizing the result of concatenated inputs. 
Value yScale, yZp; @@ -661,14 +664,14 @@ void mlir::torch::onnx_c::populateComMicrosoftDomain( "per-tensor quantization"); Torch::ValueTensorType yTy = dyn_cast( getQTorchTypeFromTorchIntType(resultType)); - Value dtyVal = rewriter.create( - loc, rewriter.getType(), + Value dtyVal = Torch::ConstantIntOp::create( + rewriter, loc, rewriter.getType(), rewriter.getIntegerAttr( rewriter.getIntegerType(64), static_cast( Torch::getScalarTypeForType(yTy.getDtype())))); - Value result = rewriter.create( - loc, yTy, concat, yScale, yZp, dtyVal); + Value result = Torch::AtenQuantizePerTensorOp::create( + rewriter, loc, yTy, concat, yScale, yZp, dtyVal); rewriter.replaceOpWithNewOp(binder.op, resultType, result); return success(); @@ -726,58 +729,58 @@ void mlir::torch::onnx_c::populateComMicrosoftDomain( // Computing the AvgPool result. SmallVector cstKernel, cstPadding, cstStrides; - Value cstZero = rewriter.create( - loc, rewriter.getI64IntegerAttr(0)); - Value cstOne = rewriter.create( - loc, rewriter.getI64IntegerAttr(1)); + Value cstZero = Torch::ConstantIntOp::create( + rewriter, loc, rewriter.getI64IntegerAttr(0)); + Value cstOne = Torch::ConstantIntOp::create( + rewriter, loc, rewriter.getI64IntegerAttr(1)); unsigned inputRank = inputShape.size(); for (unsigned i = 2; i < inputRank; i++) { if (inputShape[i] == Torch::kUnknownSize) { - Value dim = rewriter.create( - loc, rewriter.getI64IntegerAttr(i)); + Value dim = Torch::ConstantIntOp::create( + rewriter, loc, rewriter.getI64IntegerAttr(i)); Value inputDimSize = - rewriter.create(loc, x, dim); + Torch::AtenSizeIntOp::create(rewriter, loc, x, dim); cstKernel.push_back(inputDimSize); } else { int64_t kernelSize = inputShape[i] - resultShape[i] + 1; - cstKernel.push_back(rewriter.create( - loc, rewriter.getI64IntegerAttr(kernelSize))); + cstKernel.push_back(Torch::ConstantIntOp::create( + rewriter, loc, rewriter.getI64IntegerAttr(kernelSize))); } cstPadding.push_back(cstZero); cstStrides.push_back(cstOne); } - Value 
kernelSizeList = rewriter.create( - loc, + Value kernelSizeList = Torch::PrimListConstructOp::create( + rewriter, loc, Torch::ListType::get(Torch::IntType::get(binder.op->getContext())), cstKernel); - Value paddingList = rewriter.create( - loc, + Value paddingList = Torch::PrimListConstructOp::create( + rewriter, loc, Torch::ListType::get(Torch::IntType::get(binder.op->getContext())), cstPadding); - Value stridesList = rewriter.create( - loc, + Value stridesList = Torch::PrimListConstructOp::create( + rewriter, loc, Torch::ListType::get(Torch::IntType::get(binder.op->getContext())), cstStrides); - Value cstFalse = rewriter.create(loc, false); + Value cstFalse = Torch::ConstantBoolOp::create(rewriter, loc, false); Value cstCeilMode = cstFalse; Value cstCountIncludePad = cstFalse; - Value cstNone = rewriter.create(loc); + Value cstNone = Torch::ConstantNoneOp::create(rewriter, loc); auto yTy = rewriter.getType( resultShape, rewriter.getF32Type()); Value avgpool; if (inputRank == 3) { - avgpool = rewriter.create( - loc, yTy, x, kernelSizeList, stridesList, paddingList, + avgpool = Torch::AtenAvgPool1dOp::create( + rewriter, loc, yTy, x, kernelSizeList, stridesList, paddingList, cstCeilMode, cstCountIncludePad); } else if (inputRank == 4) { - avgpool = rewriter.create( - loc, yTy, x, kernelSizeList, stridesList, paddingList, + avgpool = Torch::AtenAvgPool2dOp::create( + rewriter, loc, yTy, x, kernelSizeList, stridesList, paddingList, cstCeilMode, cstCountIncludePad, /*divisor_override=*/cstNone); } else if (inputRank == 5) { - avgpool = rewriter.create( - loc, yTy, x, kernelSizeList, stridesList, paddingList, + avgpool = Torch::AtenAvgPool3dOp::create( + rewriter, loc, yTy, x, kernelSizeList, stridesList, paddingList, cstCeilMode, cstCountIncludePad, /*divisor_override=*/cstNone); } else { @@ -787,14 +790,14 @@ void mlir::torch::onnx_c::populateComMicrosoftDomain( // Quantizing the result of AvgPool op. 
yTy = dyn_cast( getQTorchTypeFromTorchIntType(resultType)); - Value dtyVal = rewriter.create( - binder.getLoc(), rewriter.getType(), + Value dtyVal = Torch::ConstantIntOp::create( + rewriter, binder.getLoc(), rewriter.getType(), rewriter.getIntegerAttr( rewriter.getIntegerType(64), static_cast( Torch::getScalarTypeForType(yTy.getDtype())))); - avgpool = rewriter.create( - loc, yTy, avgpool, yScale, yZp, dtyVal); + avgpool = Torch::AtenQuantizePerTensorOp::create( + rewriter, loc, yTy, avgpool, yScale, yZp, dtyVal); rewriter.replaceOpWithNewOp(binder.op, resultType, avgpool); return success(); @@ -837,19 +840,19 @@ void mlir::torch::onnx_c::populateComMicrosoftDomain( // Computing the Sigmoid result. auto yTy = rewriter.getType( resultType.getOptionalSizes(), rewriter.getF32Type()); - Value y = rewriter.create(loc, yTy, x); + Value y = Torch::AtenSigmoidOp::create(rewriter, loc, yTy, x); // Quantizing the result of Sigmoid op. yTy = dyn_cast( getQTorchTypeFromTorchIntType(resultType)); - Value dtyVal = rewriter.create( - binder.getLoc(), rewriter.getType(), + Value dtyVal = Torch::ConstantIntOp::create( + rewriter, binder.getLoc(), rewriter.getType(), rewriter.getIntegerAttr( rewriter.getIntegerType(64), static_cast( Torch::getScalarTypeForType(yTy.getDtype())))); - y = rewriter.create(loc, yTy, y, yScale, - yZp, dtyVal); + y = Torch::AtenQuantizePerTensorOp::create(rewriter, loc, yTy, y, + yScale, yZp, dtyVal); rewriter.replaceOpWithNewOp(binder.op, resultType, y); return success(); @@ -910,24 +913,22 @@ void mlir::torch::onnx_c::populateComMicrosoftDomain( auto yTy = rewriter.getType( resultType.getOptionalSizes(), rewriter.getF32Type()); - Value averagePool = - rewriter - .create(binder.getLoc(), yTy, newOperands, - newAttributes, - binder.op->getRegions().size()) - .getResult(0); + Value averagePool = Torch::OperatorOp::create( + rewriter, binder.getLoc(), yTy, newOperands, + newAttributes, binder.op->getRegions().size()) + .getResult(0); // Quantizing the 
result of AveragePool op. yTy = dyn_cast( getQTorchTypeFromTorchIntType(resultType)); - Value dtyVal = rewriter.create( - binder.getLoc(), rewriter.getType(), + Value dtyVal = Torch::ConstantIntOp::create( + rewriter, binder.getLoc(), rewriter.getType(), rewriter.getIntegerAttr( rewriter.getIntegerType(64), static_cast( Torch::getScalarTypeForType(yTy.getDtype())))); - averagePool = rewriter.create( - loc, yTy, averagePool, yScale, yZp, dtyVal); + averagePool = Torch::AtenQuantizePerTensorOp::create( + rewriter, loc, yTy, averagePool, yScale, yZp, dtyVal); rewriter.replaceOpWithNewOp(binder.op, resultType, averagePool); return success(); @@ -1039,20 +1040,20 @@ void mlir::torch::onnx_c::populateComMicrosoftDomain( // Computing the Mul result. auto cTy = rewriter.getType( resultType.getOptionalSizes(), rewriter.getF32Type()); - Value c = - rewriter.create(binder.getLoc(), cTy, a, b); + Value c = Torch::AtenMulTensorOp::create(rewriter, binder.getLoc(), cTy, + a, b); // Quantizing the result of Mul operation. 
cTy = dyn_cast( getQTorchTypeFromTorchIntType(resultType)); - Value dtyVal = rewriter.create( - binder.getLoc(), rewriter.getType(), + Value dtyVal = Torch::ConstantIntOp::create( + rewriter, binder.getLoc(), rewriter.getType(), rewriter.getIntegerAttr( rewriter.getIntegerType(64), static_cast( Torch::getScalarTypeForType(cTy.getDtype())))); - c = rewriter.create( - binder.getLoc(), cTy, c, cScale, cZp, dtyVal); + c = Torch::AtenQuantizePerTensorOp::create(rewriter, binder.getLoc(), + cTy, c, cScale, cZp, dtyVal); rewriter.replaceOpWithNewOp(binder.op, resultType, c); return success(); diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp index 8f14515e425c..eff45ee3d98a 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp @@ -29,18 +29,18 @@ LogicalResult windowFunctionImpl(OpBinder binder, double isPeriodicFp = static_cast(periodic); - Value zero = b.create(rewriter.getF64FloatAttr(0.0)); - Value one = b.create(rewriter.getF64FloatAttr(1.0)); - Value two = b.create(rewriter.getF64FloatAttr(2.0)); + Value zero = Torch::ConstantFloatOp::create(b, rewriter.getF64FloatAttr(0.0)); + Value one = Torch::ConstantFloatOp::create(b, rewriter.getF64FloatAttr(1.0)); + Value two = Torch::ConstantFloatOp::create(b, rewriter.getF64FloatAttr(2.0)); constexpr double pi = llvm::numbers::pi; - Value tau = b.create( - rewriter.getFloatAttr(rewriter.getF64Type(), 2.0 * pi)); + Value tau = Torch::ConstantFloatOp::create( + b, rewriter.getFloatAttr(rewriter.getF64Type(), 2.0 * pi)); - Value noneVal = b.create(); - Value cstFalse = b.create(false); - Value float32Type = b.create( - rewriter.getI64IntegerAttr(/*float32Type*/ 6)); + Value noneVal = Torch::ConstantNoneOp::create(b); + Value cstFalse = Torch::ConstantBoolOp::create(b, false); + Value float32Type = Torch::ConstantIntOp::create( + b, rewriter.getI64IntegerAttr(/*float32Type*/ 6)); // 
Create an f32 ValueTensorType with thse same size as size, the // operand @@ -48,45 +48,47 @@ LogicalResult windowFunctionImpl(OpBinder binder, dyn_cast(size.getType()).getOptionalSizes(); auto f32ResultType = rewriter.getType( shapeOfOperand, rewriter.getF32Type()); - Value periodicSizeFloat = b.create( - f32ResultType, size, float32Type, cstFalse, cstFalse, noneVal); - Value symmetricSizeFloat = b.create( - periodicSizeFloat.getType(), periodicSizeFloat, one, one); + Value periodicSizeFloat = Torch::AtenToDtypeOp::create( + b, f32ResultType, size, float32Type, cstFalse, cstFalse, noneVal); + Value symmetricSizeFloat = Torch::AtenSubScalarOp::create( + b, periodicSizeFloat.getType(), periodicSizeFloat, one, one); Value isPeriodic = - b.create(rewriter.getF64FloatAttr(isPeriodicFp)); - Value isSymmetricFloat = b.create( - rewriter.getF64FloatAttr(1.0 - isPeriodicFp)); - - Value periodicComponent = b.create( - periodicSizeFloat.getType(), periodicSizeFloat, isPeriodic); - Value symmetricComponent = b.create( - symmetricSizeFloat.getType(), symmetricSizeFloat, isSymmetricFloat); - Value sizeFloat = b.create( - symmetricComponent.getType(), symmetricComponent, periodicComponent, one); + Torch::ConstantFloatOp::create(b, rewriter.getF64FloatAttr(isPeriodicFp)); + Value isSymmetricFloat = Torch::ConstantFloatOp::create( + b, rewriter.getF64FloatAttr(1.0 - isPeriodicFp)); + + Value periodicComponent = Torch::AtenMulScalarOp::create( + b, periodicSizeFloat.getType(), periodicSizeFloat, isPeriodic); + Value symmetricComponent = Torch::AtenMulScalarOp::create( + b, symmetricSizeFloat.getType(), symmetricSizeFloat, isSymmetricFloat); + Value sizeFloat = Torch::AtenAddTensorOp::create( + b, symmetricComponent.getType(), symmetricComponent, periodicComponent, + one); // Here, size can be used in the place of periodicSizeFloat, as the // latter is just a float representation of the former. 
Value scalarLimit = getItemOp(binder, rewriter, size); - Value rangeArr = b.create( - resultType, zero, scalarLimit, one, noneVal, noneVal, noneVal, noneVal); + Value rangeArr = Torch::AtenArangeStartStepOp::create( + b, resultType, zero, scalarLimit, one, noneVal, noneVal, noneVal, + noneVal); Value rangeTimesTau = - b.create(resultType, rangeArr, tau); + Torch::AtenMulScalarOp::create(b, resultType, rangeArr, tau); Value rangeAngular = - b.create(resultType, rangeTimesTau, sizeFloat); + Torch::AtenDivTensorOp::create(b, resultType, rangeTimesTau, sizeFloat); Value twoRangeAngular = - b.create(resultType, rangeAngular, two); + Torch::AtenMulScalarOp::create(b, resultType, rangeAngular, two); - Value cosRangeAngular = b.create(resultType, rangeAngular); + Value cosRangeAngular = Torch::AtenCosOp::create(b, resultType, rangeAngular); Value cosTwoRangeAngular = - b.create(resultType, twoRangeAngular); + Torch::AtenCosOp::create(b, resultType, twoRangeAngular); Value a1Component = - b.create(resultType, cosRangeAngular, a1); + Torch::AtenMulScalarOp::create(b, resultType, cosRangeAngular, a1); Value a2Component = - b.create(resultType, cosTwoRangeAngular, a2); + Torch::AtenMulScalarOp::create(b, resultType, cosTwoRangeAngular, a2); // AtenSubScalarOp actually requires a tensor operand as the LHS, that // is, operand #1. Therefore, to avoid errors, the onnx implementation @@ -94,9 +96,9 @@ LogicalResult windowFunctionImpl(OpBinder binder, // AtenSubScalarOp has been replaced with AtenAddScalarOp, as the add // operation is commutative. 
Value subA1Component = - b.create(resultType, a1Component, a0, one); - Value result = b.create(resultType, subA1Component, - a2Component, one); + Torch::AtenAddScalarOp::create(b, resultType, a1Component, a0, one); + Value result = Torch::AtenAddTensorOp::create(b, resultType, subA1Component, + a2Component, one); std::optional dtypeIntTorch = onnxDtypeIntToTorchDtypeInt(output_datatype); @@ -104,8 +106,8 @@ LogicalResult windowFunctionImpl(OpBinder binder, return rewriter.notifyMatchFailure( binder.op, "unimplemented support for the given dtype conversion"); } - Value outputDtype = b.create( - rewriter.getType(), + Value outputDtype = Torch::ConstantIntOp::create( + b, rewriter.getType(), rewriter.getIntegerAttr(rewriter.getIntegerType(64), dtypeIntTorch.value())); @@ -145,20 +147,20 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( return success(); }); // Add became forward compatible with Torch in version 7. - patterns.onOp("Add", 7, - [](OpBinder binder, ConversionPatternRewriter &rewriter) { - Torch::ValueTensorType resultType; - Value lhs, rhs; - if (binder.tensorOperands(lhs, rhs) || - binder.tensorResultType(resultType)) - return failure(); - Value const1 = rewriter.create( - binder.getLoc(), rewriter.getType(), - rewriter.getIntegerAttr(rewriter.getIntegerType(64), 1)); - rewriter.replaceOpWithNewOp( - binder.op, resultType, lhs, rhs, const1); - return success(); - }); + patterns.onOp( + "Add", 7, [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + Value lhs, rhs; + if (binder.tensorOperands(lhs, rhs) || + binder.tensorResultType(resultType)) + return failure(); + Value const1 = Torch::ConstantIntOp::create( + rewriter, binder.getLoc(), rewriter.getType(), + rewriter.getIntegerAttr(rewriter.getIntegerType(64), 1)); + rewriter.replaceOpWithNewOp( + binder.op, resultType, lhs, rhs, const1); + return success(); + }); // TODO: AffineGrid patterns.onOp("And", 1, [](OpBinder binder, ConversionPatternRewriter 
&rewriter) { @@ -191,27 +193,28 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( if (axis < 0) axis += operandSizes.size(); - Value constAxis = rewriter.create( - binder.getLoc(), rewriter.getType(), + Value constAxis = Torch::ConstantIntOp::create( + rewriter, binder.getLoc(), rewriter.getType(), rewriter.getIntegerAttr(rewriter.getIntegerType(64), axis)); - Value constKeepDims = rewriter.create( - binder.getLoc(), rewriter.getType(), + Value constKeepDims = Torch::ConstantBoolOp::create( + rewriter, binder.getLoc(), rewriter.getType(), rewriter.getBoolAttr(keepDims)); if (selectLastIndex) { Value dims = createConstantIntList(binder, rewriter, {axis}); auto operandTy = dyn_cast(operand.getType()); - operand = rewriter.create( - binder.getLoc(), operandTy, operand, dims); - Value argmax = rewriter.create( - binder.getLoc(), resultType, operand, constAxis, constKeepDims); - Value offset = rewriter.create( - binder.getLoc(), + operand = Torch::AtenFlipOp::create(rewriter, binder.getLoc(), + operandTy, operand, dims); + Value argmax = + Torch::AtenArgmaxOp::create(rewriter, binder.getLoc(), resultType, + operand, constAxis, constKeepDims); + Value offset = Torch::ConstantIntOp::create( + rewriter, binder.getLoc(), rewriter.getI64IntegerAttr(operandSizes[axis] - 1)); - Value alpha = rewriter.create( - binder.getLoc(), rewriter.getI64IntegerAttr(1)); - Value sub = rewriter.create( - binder.getLoc(), resultType, argmax, offset, alpha); + Value alpha = Torch::ConstantIntOp::create( + rewriter, binder.getLoc(), rewriter.getI64IntegerAttr(1)); + Value sub = Torch::AtenSubScalarOp::create( + rewriter, binder.getLoc(), resultType, argmax, offset, alpha); rewriter.replaceOpWithNewOp(binder.op, resultType, sub); return success(); @@ -241,27 +244,28 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( if (axis < 0) axis += operandSizes.size(); - Value constAxis = rewriter.create( - binder.getLoc(), rewriter.getType(), + Value constAxis = Torch::ConstantIntOp::create( + 
rewriter, binder.getLoc(), rewriter.getType(), rewriter.getIntegerAttr(rewriter.getIntegerType(64), axis)); - Value constKeepDims = rewriter.create( - binder.getLoc(), rewriter.getType(), + Value constKeepDims = Torch::ConstantBoolOp::create( + rewriter, binder.getLoc(), rewriter.getType(), rewriter.getBoolAttr(keepDims)); if (selectLastIndex) { Value dims = createConstantIntList(binder, rewriter, {axis}); auto operandTy = dyn_cast(operand.getType()); - operand = rewriter.create( - binder.getLoc(), operandTy, operand, dims); - Value argmin = rewriter.create( - binder.getLoc(), resultType, operand, constAxis, constKeepDims); - Value offset = rewriter.create( - binder.getLoc(), + operand = Torch::AtenFlipOp::create(rewriter, binder.getLoc(), + operandTy, operand, dims); + Value argmin = + Torch::AtenArgminOp::create(rewriter, binder.getLoc(), resultType, + operand, constAxis, constKeepDims); + Value offset = Torch::ConstantIntOp::create( + rewriter, binder.getLoc(), rewriter.getI64IntegerAttr(operandSizes[axis] - 1)); - Value alpha = rewriter.create( - binder.getLoc(), rewriter.getI64IntegerAttr(1)); - Value sub = rewriter.create( - binder.getLoc(), resultType, argmin, offset, alpha); + Value alpha = Torch::ConstantIntOp::create( + rewriter, binder.getLoc(), rewriter.getI64IntegerAttr(1)); + Value sub = Torch::AtenSubScalarOp::create( + rewriter, binder.getLoc(), resultType, argmin, offset, alpha); rewriter.replaceOpWithNewOp(binder.op, resultType, sub); return success(); @@ -356,11 +360,11 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( return failure(); Location loc = binder.getLoc(); - Value cstFalse = rewriter.create(loc, false); - Value cstMomentum = rewriter.create( - loc, rewriter.getF64FloatAttr(momentum)); - Value cstEps = rewriter.create( - loc, rewriter.getF64FloatAttr(eps)); + Value cstFalse = Torch::ConstantBoolOp::create(rewriter, loc, false); + Value cstMomentum = Torch::ConstantFloatOp::create( + rewriter, loc, 
rewriter.getF64FloatAttr(momentum)); + Value cstEps = Torch::ConstantFloatOp::create( + rewriter, loc, rewriter.getF64FloatAttr(eps)); // When training_mode=False, the op outputs only Y, where // Y = (X - input_mean) / sqrt(input_var + epsilon) * scale + @@ -402,51 +406,54 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( SmallVector dimsToReduce; for (int64_t i = 0; i < inputRank; i++) { if (i != 1) - dimsToReduce.push_back(rewriter.create( - binder.getLoc(), rewriter.getI64IntegerAttr(i))); + dimsToReduce.push_back(Torch::ConstantIntOp::create( + rewriter, binder.getLoc(), rewriter.getI64IntegerAttr(i))); } - Value reduceDimsList = rewriter.create( - binder.getLoc(), + Value reduceDimsList = Torch::PrimListConstructOp::create( + rewriter, binder.getLoc(), Torch::ListType::get(Torch::IntType::get(binder.op->getContext())), dimsToReduce); - Value noneVal = rewriter.create(binder.getLoc()); - Value currentMean = rewriter.create( - loc, meanResultType, input, reduceDimsList, + Value noneVal = + Torch::ConstantNoneOp::create(rewriter, binder.getLoc()); + Value currentMean = Torch::AtenMeanDimOp::create( + rewriter, loc, meanResultType, input, reduceDimsList, /*keepdim=*/cstFalse, /*dtype=*/noneVal); - Value currentVar = rewriter.create( - loc, varResultType, input, reduceDimsList, + Value currentVar = Torch::AtenVarDimOp::create( + rewriter, loc, varResultType, input, reduceDimsList, /*unbiased=*/cstFalse, /*keepdim=*/cstFalse); // Computing running_mean. 
- Value inputMeanMulMomentum = rewriter.create( - loc, meanResultType, inputMean, cstMomentum); - Value currentMeanMulMomentum = rewriter.create( - loc, varResultType, currentMean, cstMomentum); - Value constantOne = rewriter.create( - loc, rewriter.getI64IntegerAttr(1)); - Value inpMeanMMSubCurMeanMM = rewriter.create( - loc, meanResultType, inputMeanMulMomentum, currentMeanMulMomentum, - constantOne); - Value runningMean = rewriter.create( - loc, meanResultType, inpMeanMMSubCurMeanMM, currentMean, + Value inputMeanMulMomentum = Torch::AtenMulScalarOp::create( + rewriter, loc, meanResultType, inputMean, cstMomentum); + Value currentMeanMulMomentum = Torch::AtenMulScalarOp::create( + rewriter, loc, varResultType, currentMean, cstMomentum); + Value constantOne = Torch::ConstantIntOp::create( + rewriter, loc, rewriter.getI64IntegerAttr(1)); + Value inpMeanMMSubCurMeanMM = Torch::AtenSubTensorOp::create( + rewriter, loc, meanResultType, inputMeanMulMomentum, + currentMeanMulMomentum, constantOne); + Value runningMean = Torch::AtenAddTensorOp::create( + rewriter, loc, meanResultType, inpMeanMMSubCurMeanMM, currentMean, constantOne); // Computing running_var. 
- Value inputVarMulMomentum = rewriter.create( - loc, varResultType, inputVar, cstMomentum); - Value currentVarMulMomentum = rewriter.create( - loc, varResultType, currentVar, cstMomentum); - Value inpVarMMSubCurVarMM = rewriter.create( - loc, varResultType, inputVarMulMomentum, currentVarMulMomentum, + Value inputVarMulMomentum = Torch::AtenMulScalarOp::create( + rewriter, loc, varResultType, inputVar, cstMomentum); + Value currentVarMulMomentum = Torch::AtenMulScalarOp::create( + rewriter, loc, varResultType, currentVar, cstMomentum); + Value inpVarMMSubCurVarMM = Torch::AtenSubTensorOp::create( + rewriter, loc, varResultType, inputVarMulMomentum, + currentVarMulMomentum, constantOne); + Value runningVar = Torch::AtenAddTensorOp::create( + rewriter, loc, varResultType, inpVarMMSubCurVarMM, currentVar, constantOne); - Value runningVar = rewriter.create( - loc, varResultType, inpVarMMSubCurVarMM, currentVar, constantOne); // Computing Y. - Value y = rewriter.create( - loc, resultType, input, weight, bias, currentMean, currentVar, + Value y = Torch::AtenBatchNormOp::create( + rewriter, loc, resultType, input, weight, bias, currentMean, + currentVar, /*training=*/cstFalse, cstMomentum, cstEps, /*cudnn_enabled=*/cstFalse); @@ -564,10 +571,11 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( Value stridesDilationsList = createConstantIntList(binder, rewriter, stridesDilations); Value cstCeilMode = - rewriter.create(binder.getLoc(), ceilMode); - Value cstCountIncludePad = rewriter.create( - binder.getLoc(), countIncludePad); - Value cstNone = rewriter.create(binder.getLoc()); + Torch::ConstantBoolOp::create(rewriter, binder.getLoc(), ceilMode); + Value cstCountIncludePad = Torch::ConstantBoolOp::create( + rewriter, binder.getLoc(), countIncludePad); + Value cstNone = + Torch::ConstantNoneOp::create(rewriter, binder.getLoc()); if (rank == 3) { rewriter.replaceOpWithNewOp( @@ -612,9 +620,10 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( "unimplemented: 
support not present for seed attribute"); } - Value none = rewriter.create(binder.getLoc()); - Value bernoulli = rewriter.create( - binder.getLoc(), input.getType(), input, /*generator=*/none); + Value none = Torch::ConstantNoneOp::create(rewriter, binder.getLoc()); + Value bernoulli = Torch::AtenBernoulliOp::create( + rewriter, binder.getLoc(), input.getType(), input, + /*generator=*/none); if (dtypeIntOnnx == -1) { // True, if dtype attribute value is not present. @@ -628,10 +637,11 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( binder.op, "unimplemented support for the given dtype conversion"); } - Value constDtype = rewriter.create( - binder.getLoc(), rewriter.getI64IntegerAttr(dtypeIntTorch.value())); + Value constDtype = Torch::ConstantIntOp::create( + rewriter, binder.getLoc(), + rewriter.getI64IntegerAttr(dtypeIntTorch.value())); Value cstFalse = - rewriter.create(binder.getLoc(), false); + Torch::ConstantBoolOp::create(rewriter, binder.getLoc(), false); rewriter.replaceOpWithNewOp( binder.op, resultType, bernoulli, constDtype, /*non_blocking=*/cstFalse, /*copy=*/cstFalse, @@ -720,11 +730,12 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( binder.op, "unimplemented support for the given dtype conversion"); } - Value constDtype = rewriter.create( - binder.getLoc(), rewriter.getI64IntegerAttr(dtypeIntTorch.value())); - Value none = rewriter.create(binder.getLoc()); + Value constDtype = Torch::ConstantIntOp::create( + rewriter, binder.getLoc(), + rewriter.getI64IntegerAttr(dtypeIntTorch.value())); + Value none = Torch::ConstantNoneOp::create(rewriter, binder.getLoc()); Value cstFalse = - rewriter.create(binder.getLoc(), false); + Torch::ConstantBoolOp::create(rewriter, binder.getLoc(), false); rewriter.replaceOpWithNewOp( binder.op, resultType, operand, constDtype, /*non_blocking=*/cstFalse, /*copy=*/cstFalse, @@ -752,9 +763,9 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( Type targetDtype = targetTy.getDtype(); Value constDtype = 
Torch::getDtypeIntValueForType( rewriter, binder.getLoc(), targetDtype); - Value none = rewriter.create(binder.getLoc()); + Value none = Torch::ConstantNoneOp::create(rewriter, binder.getLoc()); Value cstFalse = - rewriter.create(binder.getLoc(), false); + Torch::ConstantBoolOp::create(rewriter, binder.getLoc(), false); rewriter.replaceOpWithNewOp( binder.op, resultType, input, constDtype, /*non_blocking=*/cstFalse, /*copy=*/cstFalse, @@ -782,32 +793,32 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( binder.f32FloatAttr(alpha, "alpha", 1.0f)) return failure(); // exp(x/alpha) - Value constAlpha = rewriter.create( - binder.getLoc(), rewriter.getType(), + Value constAlpha = Torch::ConstantFloatOp::create( + rewriter, binder.getLoc(), rewriter.getType(), rewriter.getF64FloatAttr(alpha)); - Value xDivAlpha = rewriter.create( - binder.getLoc(), resultType, operand, constAlpha); - Value expXDivAlpha = rewriter.create( - binder.getLoc(), resultType, xDivAlpha); + Value xDivAlpha = Torch::AtenDivScalarOp::create( + rewriter, binder.getLoc(), resultType, operand, constAlpha); + Value expXDivAlpha = Torch::AtenExpOp::create(rewriter, binder.getLoc(), + resultType, xDivAlpha); // alpha * (exp(x/alpha) - 1) - Value constantOne = rewriter.create( - binder.getLoc(), rewriter.getI64IntegerAttr(1)); - Value subOne = rewriter.create( - binder.getLoc(), resultType, expXDivAlpha, constantOne, - constantOne); - Value mulAlpha = rewriter.create( - binder.getLoc(), resultType, subOne, constAlpha); - Value constantZero = rewriter.create( - binder.getLoc(), rewriter.getI64IntegerAttr(0)); + Value constantOne = Torch::ConstantIntOp::create( + rewriter, binder.getLoc(), rewriter.getI64IntegerAttr(1)); + Value subOne = Torch::AtenSubScalarOp::create(rewriter, binder.getLoc(), + resultType, expXDivAlpha, + constantOne, constantOne); + Value mulAlpha = Torch::AtenMulScalarOp::create( + rewriter, binder.getLoc(), resultType, subOne, constAlpha); + Value constantZero = 
Torch::ConstantIntOp::create( + rewriter, binder.getLoc(), rewriter.getI64IntegerAttr(0)); Value zeroTensor = createRank0Tensor(rewriter, binder.getLoc(), resultType, constantZero); // min(0, alpha * (exp(x/alpha) - 1)) - Value minExpression = rewriter.create( - binder.getLoc(), resultType, zeroTensor, mulAlpha); + Value minExpression = Torch::AtenMinimumOp::create( + rewriter, binder.getLoc(), resultType, zeroTensor, mulAlpha); // max(0, x) - Value maxExpression = rewriter.create( - binder.getLoc(), resultType, zeroTensor, operand); + Value maxExpression = Torch::AtenMaximumOp::create( + rewriter, binder.getLoc(), resultType, zeroTensor, operand); // max(0,x) + min(0, alpha * (exp(x/alpha) - 1)) rewriter.replaceOpWithNewOp( binder.op, resultType, maxExpression, minExpression, constantOne); @@ -834,13 +845,13 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( } int64_t axesSize = axes.size(); - Value none = rewriter.create(binder.getLoc()); - Value cstZero = rewriter.create( - binder.getLoc(), rewriter.getI64IntegerAttr(0)); - Value cstOne = rewriter.create( - binder.getLoc(), rewriter.getI64IntegerAttr(1)); - Value cstTwo = rewriter.create( - binder.getLoc(), rewriter.getI64IntegerAttr(2)); + Value none = Torch::ConstantNoneOp::create(rewriter, binder.getLoc()); + Value cstZero = Torch::ConstantIntOp::create( + rewriter, binder.getLoc(), rewriter.getI64IntegerAttr(0)); + Value cstOne = Torch::ConstantIntOp::create( + rewriter, binder.getLoc(), rewriter.getI64IntegerAttr(1)); + Value cstTwo = Torch::ConstantIntOp::create( + rewriter, binder.getLoc(), rewriter.getI64IntegerAttr(2)); auto scalarTensorType = rewriter.getType( ArrayRef{}, rewriter.getIntegerType(64, /*signed*/ 1)); auto selectTensorType = rewriter.getType( @@ -859,8 +870,8 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( auto interType = rewriter.getType( interShape, resultType.getOptionalDtype()); - Value modeVal = rewriter.create( - binder.getLoc(), rewriter.getStringAttr("floor")); + 
Value modeVal = Torch::ConstantStrOp::create( + rewriter, binder.getLoc(), rewriter.getStringAttr("floor")); for (int i = 0; i < axesSize; i++) { if (axes[i] < 0) axes[i] += rank; @@ -868,69 +879,78 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( continue; auto opType = axes[i] == lastChangeDim ? resultType : interType; - Value axis = rewriter.create( - binder.getLoc(), rewriter.getI64IntegerAttr(axes[i])); - Value k = rewriter.create( - binder.getLoc(), rewriter.getI64IntegerAttr(i)); - Value kTensor = rewriter.create( - binder.getLoc(), scalarTensorType, k); - Value sel = rewriter.create( - binder.getLoc(), selectTensorType, shape, cstZero, kTensor); - Value outputDimSize = rewriter.create( - binder.getLoc(), rewriter.getType(), sel); - Value inputDimSize = rewriter.create( - binder.getLoc(), input, - rewriter.create( - binder.getLoc(), rewriter.getI64IntegerAttr(axes[i]))); + Value axis = Torch::ConstantIntOp::create( + rewriter, binder.getLoc(), rewriter.getI64IntegerAttr(axes[i])); + Value k = Torch::ConstantIntOp::create(rewriter, binder.getLoc(), + rewriter.getI64IntegerAttr(i)); + Value kTensor = Torch::PrimNumToTensorScalarOp::create( + rewriter, binder.getLoc(), scalarTensorType, k); + Value sel = Torch::AtenIndexSelectOp::create( + rewriter, binder.getLoc(), selectTensorType, shape, cstZero, + kTensor); + Value outputDimSize = Torch::AtenItemOp::create( + rewriter, binder.getLoc(), rewriter.getType(), + sel); + Value inputDimSize = Torch::AtenSizeIntOp::create( + rewriter, binder.getLoc(), input, + Torch::ConstantIntOp::create( + rewriter, binder.getLoc(), + rewriter.getI64IntegerAttr(axes[i]))); if (inputShape[axes[i]] > resultShape[axes[i]]) { - Value sub = rewriter.create( - binder.getLoc(), inputDimSize, outputDimSize); - Value subTensor = rewriter.create( - binder.getLoc(), scalarTensorType, sub); - Value div = rewriter.create( - binder.getLoc(), scalarTensorType, subTensor, cstTwo, modeVal); - Value start = rewriter.create( - 
binder.getLoc(), rewriter.getType(), div); - Value end = rewriter.create( - binder.getLoc(), start, outputDimSize); - input = rewriter.create( - binder.getLoc(), opType, input, axis, start, end, cstOne); + Value sub = Torch::AtenSubIntOp::create( + rewriter, binder.getLoc(), inputDimSize, outputDimSize); + Value subTensor = Torch::PrimNumToTensorScalarOp::create( + rewriter, binder.getLoc(), scalarTensorType, sub); + Value div = Torch::AtenDivScalarModeOp::create( + rewriter, binder.getLoc(), scalarTensorType, subTensor, cstTwo, + modeVal); + Value start = Torch::AtenItemOp::create( + rewriter, binder.getLoc(), rewriter.getType(), + div); + Value end = Torch::AtenAddIntOp::create(rewriter, binder.getLoc(), + start, outputDimSize); + input = Torch::AtenSliceTensorOp::create(rewriter, binder.getLoc(), + opType, input, axis, start, + end, cstOne); } else { - Value sub = rewriter.create( - binder.getLoc(), outputDimSize, inputDimSize); - Value subTensor = rewriter.create( - binder.getLoc(), scalarTensorType, sub); - Value div = rewriter.create( - binder.getLoc(), scalarTensorType, subTensor, cstTwo, modeVal); - Value start = rewriter.create( - binder.getLoc(), rewriter.getType(), div); - Value end = rewriter.create( - binder.getLoc(), start, inputDimSize); + Value sub = Torch::AtenSubIntOp::create( + rewriter, binder.getLoc(), outputDimSize, inputDimSize); + Value subTensor = Torch::PrimNumToTensorScalarOp::create( + rewriter, binder.getLoc(), scalarTensorType, sub); + Value div = Torch::AtenDivScalarModeOp::create( + rewriter, binder.getLoc(), scalarTensorType, subTensor, cstTwo, + modeVal); + Value start = Torch::AtenItemOp::create( + rewriter, binder.getLoc(), rewriter.getType(), + div); + Value end = Torch::AtenAddIntOp::create(rewriter, binder.getLoc(), + start, inputDimSize); SmallVector zerosShapeValues; for (int j = 0; j < rank; j++) { if (j == axes[i]) { zerosShapeValues.push_back(outputDimSize); } else { - Value dimSize = rewriter.create( - binder.getLoc(), 
input, - rewriter.create( - binder.getLoc(), rewriter.getI64IntegerAttr(j))); + Value dimSize = Torch::AtenSizeIntOp::create( + rewriter, binder.getLoc(), input, + Torch::ConstantIntOp::create( + rewriter, binder.getLoc(), + rewriter.getI64IntegerAttr(j))); zerosShapeValues.push_back(dimSize); } } - Value zerosShapeList = rewriter.create( - binder.getLoc(), + Value zerosShapeList = Torch::PrimListConstructOp::create( + rewriter, binder.getLoc(), rewriter.getType( rewriter.getType()), zerosShapeValues); - Value zeros = rewriter.create( - binder.getLoc(), opType, zerosShapeList, none, none, none, - none); - input = rewriter.create( - binder.getLoc(), opType, zeros, input, axis, start, end, - cstOne); + Value zeros = Torch::AtenZerosOp::create(rewriter, binder.getLoc(), + opType, zerosShapeList, + none, none, none, none); + input = Torch::AtenSliceScatterOp::create(rewriter, binder.getLoc(), + opType, zeros, input, + axis, start, end, cstOne); } } @@ -968,8 +988,8 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( auto minSplatAttr = SplatElementsAttr::get( resultType.toBuiltinTensor(), rewriter.getFloatAttr(resultDtype, minValue)); - min = rewriter.create( - binder.getLoc(), resultType, minSplatAttr); + min = Torch::ValueTensorLiteralOp::create(rewriter, binder.getLoc(), + resultType, minSplatAttr); } if (!max && binder.op->hasAttr("torch.onnx.max")) { float maxValue; @@ -979,8 +999,8 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( auto maxSplatAttr = SplatElementsAttr::get( resultType.toBuiltinTensor(), rewriter.getFloatAttr(resultDtype, maxValue)); - max = rewriter.create( - binder.getLoc(), resultType, maxSplatAttr); + max = Torch::ValueTensorLiteralOp::create(rewriter, binder.getLoc(), + resultType, maxSplatAttr); } if (!min && !max) { @@ -1021,12 +1041,12 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( .getDtype(); auto nonzeroType = rewriter.getType(nonzeroShape, dtype); - Value indexVal = rewriter.create( - binder.getLoc(), nonzeroType, 
conditionTensor); - Value cstZero = rewriter.create( - binder.getLoc(), rewriter.getI64IntegerAttr(0)); - Value cstNegOne = rewriter.create( - binder.getLoc(), rewriter.getI64IntegerAttr(-1)); + Value indexVal = Torch::AtenNonzeroOp::create( + rewriter, binder.getLoc(), nonzeroType, conditionTensor); + Value cstZero = Torch::ConstantIntOp::create( + rewriter, binder.getLoc(), rewriter.getI64IntegerAttr(0)); + Value cstNegOne = Torch::ConstantIntOp::create( + rewriter, binder.getLoc(), rewriter.getI64IntegerAttr(-1)); int64_t numElements = 1; for (auto i : shapeSizes) { numElements *= i; @@ -1034,8 +1054,9 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( SmallVector flattenShape = {numElements}; auto flattenType = rewriter.getType( flattenShape, resultType.getDtype()); - Value flattenTensor = rewriter.create( - binder.getLoc(), flattenType, operand, cstZero, cstNegOne); + Value flattenTensor = Torch::AtenFlattenUsingIntsOp::create( + rewriter, binder.getLoc(), flattenType, operand, cstZero, + cstNegOne); rewriter.replaceOpWithNewOp( binder.op, resultType, flattenTensor, cstZero, indexVal); return success(); @@ -1049,10 +1070,10 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( .getDtype(); auto nonzeroType = rewriter.getType(nonzeroShape, dtype); - Value indexVal = rewriter.create( - binder.getLoc(), nonzeroType, conditionTensor); - Value dimVal = rewriter.create( - binder.getLoc(), rewriter.getI64IntegerAttr(axis)); + Value indexVal = Torch::AtenNonzeroOp::create( + rewriter, binder.getLoc(), nonzeroType, conditionTensor); + Value dimVal = Torch::ConstantIntOp::create( + rewriter, binder.getLoc(), rewriter.getI64IntegerAttr(axis)); rewriter.replaceOpWithNewOp( binder.op, resultType, operand, dimVal, indexVal); return success(); @@ -1071,10 +1092,10 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( .getWithSizesAndDtype(/*optionalSizes=*/std::nullopt, /*optionalDtype=*/nullptr); Type listType = Torch::ListType::get(listElemType); - Value 
tensorList = rewriter.create( - binder.op->getLoc(), listType, tensors); - Value cstDim = rewriter.create( - binder.getLoc(), rewriter.getI64IntegerAttr(dim)); + Value tensorList = Torch::PrimListConstructOp::create( + rewriter, binder.op->getLoc(), listType, tensors); + Value cstDim = Torch::ConstantIntOp::create( + rewriter, binder.getLoc(), rewriter.getI64IntegerAttr(dim)); rewriter.replaceOpWithNewOp(binder.op, resultType, tensorList, cstDim); return success(); @@ -1245,34 +1266,35 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( Value paddingList = createConstantIntList(binder, rewriter, padOnEachAxis); - Value zero = rewriter.create( - binder.getLoc(), rewriter.getI64IntegerAttr(0)); + Value zero = Torch::ConstantIntOp::create( + rewriter, binder.getLoc(), rewriter.getI64IntegerAttr(0)); // Index the imageShape and blockShape tensors, as AtenCol2imOp expects // them to be int lists. auto select = [&](Value v, Value k, Torch::ValueTensorType ty) -> Value { - Value kTensor = rewriter.create( - binder.getLoc(), + Value kTensor = Torch::PrimNumToTensorScalarOp::create( + rewriter, binder.getLoc(), Torch::ValueTensorType::get( binder.op->getContext(), ArrayRef{1}, rewriter.getIntegerType(64, /*signed*/ 1)), k); - auto sel = rewriter.create( - binder.getLoc(), + auto sel = Torch::AtenIndexSelectOp::create( + rewriter, binder.getLoc(), Torch::ValueTensorType::get(ty.getContext(), ArrayRef{1}, ty.getOptionalDtype()), v, zero, kTensor); - Value item = rewriter.create( - binder.getLoc(), rewriter.getType(), sel); + Value item = Torch::AtenItemOp::create( + rewriter, binder.getLoc(), rewriter.getType(), + sel); return item; }; SmallVector imageShapeContainer, blockShapeContainer; for (int64_t i = 0; i < imageShapeSizes[0]; ++i) { - Value k = rewriter.create( - binder.getLoc(), rewriter.getI64IntegerAttr(i)); + Value k = Torch::ConstantIntOp::create(rewriter, binder.getLoc(), + rewriter.getI64IntegerAttr(i)); // Passing in the shapeType of each of these tensors 
avoids // repeated casts, as these have already been calculated. @@ -1280,12 +1302,12 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( blockShapeContainer.push_back(select(blockShape, k, blockShapeTy)); } - Value imageShapeAsList = rewriter.create( - binder.getLoc(), + Value imageShapeAsList = Torch::PrimListConstructOp::create( + rewriter, binder.getLoc(), Torch::ListType::get(Torch::IntType::get(binder.op->getContext())), imageShapeContainer); - Value blockShapeAsList = rewriter.create( - binder.getLoc(), + Value blockShapeAsList = Torch::PrimListConstructOp::create( + rewriter, binder.getLoc(), Torch::ListType::get(Torch::IntType::get(binder.op->getContext())), blockShapeContainer); @@ -1373,12 +1395,12 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( // Use the padding values for (int64_t pad : padding) - paddingValues.push_back(rewriter.create( - loc, rewriter.getI64IntegerAttr(pad))); + paddingValues.push_back(Torch::ConstantIntOp::create( + rewriter, loc, rewriter.getI64IntegerAttr(pad))); } else if (autoPad == "VALID") { for (int64_t pad : defaultPadding) - paddingValues.push_back(rewriter.create( - loc, rewriter.getI64IntegerAttr(pad))); + paddingValues.push_back(Torch::ConstantIntOp::create( + rewriter, loc, rewriter.getI64IntegerAttr(pad))); } else { const bool isSameLower = autoPad == "SAME_LOWER"; const unsigned spatialRank = rank - 2; @@ -1386,65 +1408,64 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( for (unsigned dimIdx = 0; dimIdx < spatialRank; dimIdx++) { // dilatedSize = dilations[dimIdx]*(weightShape[dimIdx + 2] - 1) + 1 - Value cstOne = rewriter.create( - loc, rewriter.getI64IntegerAttr(1)); - Value dilationValue = rewriter.create( - loc, rewriter.getI64IntegerAttr(dilations[dimIdx])); + Value cstOne = Torch::ConstantIntOp::create( + rewriter, loc, rewriter.getI64IntegerAttr(1)); + Value dilationValue = Torch::ConstantIntOp::create( + rewriter, loc, rewriter.getI64IntegerAttr(dilations[dimIdx])); Value weightDimSize = 
Torch::getTensorDimSize(rewriter, weight, dimIdx + 2); - Value weightMinusOne = rewriter.create( - loc, weightDimSize, cstOne); - Value dilationMulWeight = rewriter.create( - loc, dilationValue, weightMinusOne); - Value dilatedKernelSize = rewriter.create( - loc, dilationMulWeight, cstOne); + Value weightMinusOne = Torch::AtenSubIntOp::create( + rewriter, loc, weightDimSize, cstOne); + Value dilationMulWeight = Torch::AtenMulIntOp::create( + rewriter, loc, dilationValue, weightMinusOne); + Value dilatedKernelSize = Torch::AtenAddIntOp::create( + rewriter, loc, dilationMulWeight, cstOne); // totalPad = (((inputShape[dimIdx + 2] + strides[dimIdx] -1) / // strides[dimIdx]) - 1) * strides[dimIdx] + // dilatedKernelSize - inputShape[dimIdx + 2]; - Value stridesValue = rewriter.create( - loc, rewriter.getI64IntegerAttr(strides[dimIdx])); + Value stridesValue = Torch::ConstantIntOp::create( + rewriter, loc, rewriter.getI64IntegerAttr(strides[dimIdx])); Value inputDimSize = Torch::getTensorDimSize(rewriter, input, dimIdx + 2); - Value stridesMinusOne = - rewriter.create(loc, stridesValue, cstOne); - Value inputStrides = rewriter.create( - loc, inputDimSize, stridesMinusOne); - inputStrides = rewriter.create( - loc, inputStrides, stridesValue); - inputStrides = - rewriter.create(loc, inputStrides, cstOne); - inputStrides = rewriter.create( - loc, inputStrides, stridesValue); - Value strideWithDilation = rewriter.create( - loc, inputStrides, dilatedKernelSize); - Value totalPad = rewriter.create( - loc, strideWithDilation, inputDimSize); + Value stridesMinusOne = Torch::AtenSubIntOp::create( + rewriter, loc, stridesValue, cstOne); + Value inputStrides = Torch::AtenAddIntOp::create( + rewriter, loc, inputDimSize, stridesMinusOne); + inputStrides = Torch::AtenFloordivIntOp::create( + rewriter, loc, inputStrides, stridesValue); + inputStrides = Torch::AtenSubIntOp::create(rewriter, loc, + inputStrides, cstOne); + inputStrides = Torch::AtenMulIntOp::create( + rewriter, loc, 
inputStrides, stridesValue); + Value strideWithDilation = Torch::AtenAddIntOp::create( + rewriter, loc, inputStrides, dilatedKernelSize); + Value totalPad = Torch::AtenSubIntOp::create( + rewriter, loc, strideWithDilation, inputDimSize); // totalPad = totalPad > 0 ? totalPad : 0; - Value cstZero = rewriter.create( - loc, rewriter.getI64IntegerAttr(0)); + Value cstZero = Torch::ConstantIntOp::create( + rewriter, loc, rewriter.getI64IntegerAttr(0)); totalPad = - rewriter.create(loc, totalPad, cstZero); + Torch::PrimMaxIntOp::create(rewriter, loc, totalPad, cstZero); // padding[dimIdx] = // isSameLower ? ((totalPad + 1) / 2) : (totalPad / 2); // padding[spatialRank + dimIdx] = totalPad - padding[dimIdx]; - Value cstTwo = rewriter.create( - loc, rewriter.getI64IntegerAttr(2)); + Value cstTwo = Torch::ConstantIntOp::create( + rewriter, loc, rewriter.getI64IntegerAttr(2)); if (isSameLower) { auto padPlusOne = - rewriter.create(loc, totalPad, cstOne); - paddingValues[dimIdx] = rewriter.create( - loc, padPlusOne, cstTwo); + Torch::AtenAddIntOp::create(rewriter, loc, totalPad, cstOne); + paddingValues[dimIdx] = Torch::AtenFloordivIntOp::create( + rewriter, loc, padPlusOne, cstTwo); } else { - paddingValues[dimIdx] = rewriter.create( - loc, totalPad, cstTwo); + paddingValues[dimIdx] = Torch::AtenFloordivIntOp::create( + rewriter, loc, totalPad, cstTwo); } - paddingValues[spatialRank + dimIdx] = - rewriter.create(loc, totalPad, - paddingValues[dimIdx]); + paddingValues[spatialRank + dimIdx] = Torch::AtenSubIntOp::create( + rewriter, loc, totalPad, paddingValues[dimIdx]); } } @@ -1459,13 +1480,13 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( Value paddedInput = input; Value paddingList; - Value cstZero = rewriter.create( - loc, rewriter.getI64IntegerAttr(0)); + Value cstZero = Torch::ConstantIntOp::create( + rewriter, loc, rewriter.getI64IntegerAttr(0)); if (paddingValues.size() != 2 * (rank - 2)) { cstPadding = paddingValues; - paddingList = rewriter.create( - loc, 
+ paddingList = Torch::PrimListConstructOp::create( + rewriter, loc, Torch::ListType::get( Torch::IntType::get(binder.op->getContext())), cstPadding); @@ -1493,8 +1514,8 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( if (matchedPads) { for (unsigned i = 0; i < paddingValues.size() / 2; i++) cstPadding.push_back(paddingValues[i]); - paddingList = rewriter.create( - loc, + paddingList = Torch::PrimListConstructOp::create( + rewriter, loc, Torch::ListType::get( Torch::IntType::get(binder.op->getContext())), cstPadding); @@ -1510,28 +1531,26 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( } // The conv op itself will have no padding since the actual padding // is performed using the torch.pad preceding it. - paddingList = rewriter.create( - loc, + paddingList = Torch::PrimListConstructOp::create( + rewriter, loc, Torch::ListType::get( Torch::IntType::get(binder.op->getContext())), inputPaddingList); Value padsSizeList = - rewriter - .create( - loc, - Torch::ListType::get( - rewriter.getType()), - padsRearrange) + Torch::PrimListConstructOp::create( + rewriter, loc, + Torch::ListType::get(rewriter.getType()), + padsRearrange) .getResult(); - Value modeVal = rewriter.create( - loc, rewriter.getStringAttr("constant")); + Value modeVal = Torch::ConstantStrOp::create( + rewriter, loc, rewriter.getStringAttr("constant")); Value constantValue; if (isa(inputTensorType.getDtype())) constantValue = cstZero; if (isa(inputTensorType.getDtype())) - constantValue = rewriter.create( - loc, rewriter.getF64FloatAttr(0.0f)); + constantValue = Torch::ConstantFloatOp::create( + rewriter, loc, rewriter.getF64FloatAttr(0.0f)); auto getPadOutputSizeForInput = [&](int64_t low, int64_t high, int64_t inputSize) { @@ -1555,45 +1574,46 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( auto padTy = rewriter.getType( newInputShape, inputTensorType.getDtype()); - paddedInput = rewriter.create( - loc, padTy, input, padsSizeList, modeVal, constantValue); + paddedInput = + 
Torch::AtenPadOp::create(rewriter, loc, padTy, input, + padsSizeList, modeVal, constantValue); } } for (int64_t i : dilations) { - cstDilations.push_back(rewriter.create( - loc, rewriter.getI64IntegerAttr(i))); + cstDilations.push_back(Torch::ConstantIntOp::create( + rewriter, loc, rewriter.getI64IntegerAttr(i))); } for (int64_t i : strides) { - cstStrides.push_back(rewriter.create( - loc, rewriter.getI64IntegerAttr(i))); + cstStrides.push_back(Torch::ConstantIntOp::create( + rewriter, loc, rewriter.getI64IntegerAttr(i))); } cstOutputPadding = {cstZero, cstZero}; - Value dilationsList = rewriter.create( - loc, + Value dilationsList = Torch::PrimListConstructOp::create( + rewriter, loc, Torch::ListType::get(Torch::IntType::get(binder.op->getContext())), cstDilations); - Value stridesList = rewriter.create( - loc, + Value stridesList = Torch::PrimListConstructOp::create( + rewriter, loc, Torch::ListType::get(Torch::IntType::get(binder.op->getContext())), cstStrides); - Value outputPaddingList = rewriter.create( - loc, + Value outputPaddingList = Torch::PrimListConstructOp::create( + rewriter, loc, Torch::ListType::get(Torch::IntType::get(binder.op->getContext())), cstOutputPadding); - Value transposed = rewriter.create(loc, false); + Value transposed = Torch::ConstantBoolOp::create(rewriter, loc, false); Value bias; if (binder.op->getNumOperands() == 3) { if (binder.tensorOperandAtIndex(bias, 2)) { return failure(); } } else { - bias = rewriter.create(loc); + bias = Torch::ConstantNoneOp::create(rewriter, loc); } - Value cstGroup = rewriter.create( - loc, rewriter.getI64IntegerAttr(group)); + Value cstGroup = Torch::ConstantIntOp::create( + rewriter, loc, rewriter.getI64IntegerAttr(group)); rewriter.replaceOpWithNewOp( binder.op, resultType, paddedInput, weight, bias, stridesList, @@ -1681,33 +1701,35 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( return rewriter.notifyMatchFailure( binder.op, "strides list size does not match the number of axes"); - Value 
scale = rewriter.create( - binder.getLoc(), rewriter.getType(), + Value scale = Torch::ConstantFloatOp::create( + rewriter, binder.getLoc(), rewriter.getType(), rewriter.getF64FloatAttr(1.0)); if (binder.tensorOperandAtIndex(inputZp, 2)) { - inputZp = rewriter.create( - binder.getLoc(), rewriter.getI64IntegerAttr(0)); + inputZp = Torch::ConstantIntOp::create(rewriter, binder.getLoc(), + rewriter.getI64IntegerAttr(0)); } else { - inputZp = rewriter.create( - binder.getLoc(), rewriter.getType(), inputZp); + inputZp = Torch::AtenItemOp::create( + rewriter, binder.getLoc(), rewriter.getType(), + inputZp); } if (binder.tensorOperandAtIndex(weightZp, 3)) - weightZp = rewriter.create( - binder.getLoc(), rewriter.getI64IntegerAttr(0)); + weightZp = Torch::ConstantIntOp::create( + rewriter, binder.getLoc(), rewriter.getI64IntegerAttr(0)); // TODO: support per channel quantization if weightZp is a 1-D tensor if (auto zpTy = dyn_cast(weightZp.getType())) { for (auto dim : zpTy.getSizes()) if (dim != 1) return failure(); - weightZp = rewriter.create( - binder.getLoc(), rewriter.getType(), weightZp); + weightZp = Torch::AtenItemOp::create( + rewriter, binder.getLoc(), rewriter.getType(), + weightZp); } SmallVector cstPadding; if (padding.size() != 2 * (rank - 2)) { for (int64_t i : padding) { - cstPadding.push_back(rewriter.create( - binder.getLoc(), rewriter.getI64IntegerAttr(i))); + cstPadding.push_back(Torch::ConstantIntOp::create( + rewriter, binder.getLoc(), rewriter.getI64IntegerAttr(i))); } } else { for (unsigned i = 0; i < padding.size() / 2; i++) { @@ -1718,13 +1740,14 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( binder.op, "unsupported conversion: padding values for the beginning " "and ending along each spatial axis must be equal"); - cstPadding.push_back(rewriter.create( - binder.getLoc(), rewriter.getI64IntegerAttr(padding[i]))); + cstPadding.push_back(Torch::ConstantIntOp::create( + rewriter, binder.getLoc(), + rewriter.getI64IntegerAttr(padding[i]))); 
} } - Value paddingList = rewriter.create( - binder.getLoc(), + Value paddingList = Torch::PrimListConstructOp::create( + rewriter, binder.getLoc(), rewriter.getType( rewriter.getType()), cstPadding); @@ -1734,17 +1757,17 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( Value outputPaddingList = createConstantIntList(binder, rewriter, {0, 0}); Value transposed = - rewriter.create(binder.getLoc(), false); - Value bias = rewriter.create(binder.getLoc()); - Value cstGroup = rewriter.create( - binder.getLoc(), rewriter.getI64IntegerAttr(group)); + Torch::ConstantBoolOp::create(rewriter, binder.getLoc(), false); + Value bias = Torch::ConstantNoneOp::create(rewriter, binder.getLoc()); + Value cstGroup = Torch::ConstantIntOp::create( + rewriter, binder.getLoc(), rewriter.getI64IntegerAttr(group)); Type inputQTy = getQTorchTypeFromTorchIntType(inputTy); Type weightQTy = getQTorchTypeFromTorchIntType(weightTy); - input = rewriter.create( - binder.getLoc(), inputQTy, input, scale, inputZp); - weight = rewriter.create( - binder.getLoc(), weightQTy, weight, scale, weightZp); + input = Torch::Aten_MakePerTensorQuantizedTensorOp::create( + rewriter, binder.getLoc(), inputQTy, input, scale, inputZp); + weight = Torch::Aten_MakePerTensorQuantizedTensorOp::create( + rewriter, binder.getLoc(), weightQTy, weight, scale, weightZp); rewriter.replaceOpWithNewOp( binder.op, resultType, input, weight, bias, stridesList, @@ -1904,8 +1927,8 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( cstOutputPadding; if (padding.size() != 2 * (rank - 2)) { for (int64_t i : padding) { - cstPadding.push_back(rewriter.create( - binder.getLoc(), rewriter.getI64IntegerAttr(i))); + cstPadding.push_back(Torch::ConstantIntOp::create( + rewriter, binder.getLoc(), rewriter.getI64IntegerAttr(i))); } } else { for (unsigned i = 0; i < padding.size() / 2; i++) { @@ -1917,51 +1940,52 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( "unsupported conversion: padding values for the beginning " "and 
ending along each spatial axis must be equal"); } - cstPadding.push_back(rewriter.create( - binder.getLoc(), rewriter.getI64IntegerAttr(padding[i]))); + cstPadding.push_back(Torch::ConstantIntOp::create( + rewriter, binder.getLoc(), + rewriter.getI64IntegerAttr(padding[i]))); } } for (int64_t i : dilations) { - cstDilations.push_back(rewriter.create( - binder.getLoc(), rewriter.getI64IntegerAttr(i))); + cstDilations.push_back(Torch::ConstantIntOp::create( + rewriter, binder.getLoc(), rewriter.getI64IntegerAttr(i))); } for (int64_t i : strides) { - cstStrides.push_back(rewriter.create( - binder.getLoc(), rewriter.getI64IntegerAttr(i))); + cstStrides.push_back(Torch::ConstantIntOp::create( + rewriter, binder.getLoc(), rewriter.getI64IntegerAttr(i))); } for (int64_t i : outputPadding) { - cstOutputPadding.push_back(rewriter.create( - binder.getLoc(), rewriter.getI64IntegerAttr(i))); + cstOutputPadding.push_back(Torch::ConstantIntOp::create( + rewriter, binder.getLoc(), rewriter.getI64IntegerAttr(i))); } - Value paddingList = rewriter.create( - binder.getLoc(), + Value paddingList = Torch::PrimListConstructOp::create( + rewriter, binder.getLoc(), Torch::ListType::get(Torch::IntType::get(binder.op->getContext())), cstPadding); - Value dilationsList = rewriter.create( - binder.getLoc(), + Value dilationsList = Torch::PrimListConstructOp::create( + rewriter, binder.getLoc(), Torch::ListType::get(Torch::IntType::get(binder.op->getContext())), cstDilations); - Value stridesList = rewriter.create( - binder.getLoc(), + Value stridesList = Torch::PrimListConstructOp::create( + rewriter, binder.getLoc(), Torch::ListType::get(Torch::IntType::get(binder.op->getContext())), cstStrides); - Value outputPaddingList = rewriter.create( - binder.getLoc(), + Value outputPaddingList = Torch::PrimListConstructOp::create( + rewriter, binder.getLoc(), Torch::ListType::get(Torch::IntType::get(binder.op->getContext())), cstOutputPadding); Value transposed = - rewriter.create(binder.getLoc(), 
true); + Torch::ConstantBoolOp::create(rewriter, binder.getLoc(), true); Value bias; if (binder.op->getNumOperands() == 3) { if (binder.tensorOperandAtIndex(bias, 2)) { return failure(); } } else { - bias = rewriter.create(binder.getLoc()); + bias = Torch::ConstantNoneOp::create(rewriter, binder.getLoc()); } - Value cstGroup = rewriter.create( - binder.getLoc(), rewriter.getI64IntegerAttr(group)); + Value cstGroup = Torch::ConstantIntOp::create( + rewriter, binder.getLoc(), rewriter.getI64IntegerAttr(group)); rewriter.replaceOpWithNewOp( binder.op, resultType, input, weight, bias, stridesList, @@ -2012,47 +2036,48 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( // deal with neg axis: if (axis < 0) axis += rank int64_t rank = cast(operand.getType()).getSizes().size(); - Value rankVal = rewriter.create( - binder.getLoc(), rewriter.getType(), + Value rankVal = Torch::ConstantIntOp::create( + rewriter, binder.getLoc(), rewriter.getType(), rewriter.getIntegerAttr(rewriter.getIntegerType(64), rank)); - Value cstZero = rewriter.create( - binder.getLoc(), rewriter.getI64IntegerAttr(0)); - Value cstOne = rewriter.create( - binder.getLoc(), rewriter.getI64IntegerAttr(1)); - - Value axisScalar = rewriter.create( - binder.getLoc(), rewriter.getType(), axisTensor); - Value isNegative = rewriter.create( - binder.getLoc(), axisScalar, cstZero); + Value cstZero = Torch::ConstantIntOp::create( + rewriter, binder.getLoc(), rewriter.getI64IntegerAttr(0)); + Value cstOne = Torch::ConstantIntOp::create( + rewriter, binder.getLoc(), rewriter.getI64IntegerAttr(1)); + + Value axisScalar = Torch::AtenItemOp::create( + rewriter, binder.getLoc(), rewriter.getType(), + axisTensor); + Value isNegative = Torch::AtenLtIntOp::create(rewriter, binder.getLoc(), + axisScalar, cstZero); isNegative = - rewriter.create(binder.getLoc(), isNegative); - Value finalOffset = rewriter.create( - binder.getLoc(), isNegative, rankVal); - Value axis = rewriter.create( - binder.getLoc(), axisScalar, 
finalOffset); - Value none = rewriter.create(binder.getLoc()); + Torch::AtenIntBoolOp::create(rewriter, binder.getLoc(), isNegative); + Value finalOffset = Torch::AtenMulIntOp::create( + rewriter, binder.getLoc(), isNegative, rankVal); + Value axis = Torch::AtenAddIntOp::create(rewriter, binder.getLoc(), + axisScalar, finalOffset); + Value none = Torch::ConstantNoneOp::create(rewriter, binder.getLoc()); Value res; if (reverse) { - Value dims = rewriter.create( - binder.getLoc(), + Value dims = Torch::PrimListConstructOp::create( + rewriter, binder.getLoc(), rewriter.getType( rewriter.getType()), SmallVector{axis}); - Value flip = rewriter.create( - binder.getLoc(), resultType, operand, dims); - Value cumsum = rewriter.create( - binder.getLoc(), resultType, flip, axis, none); - res = rewriter.create(binder.getLoc(), resultType, - cumsum, dims); + Value flip = Torch::AtenFlipOp::create(rewriter, binder.getLoc(), + resultType, operand, dims); + Value cumsum = Torch::AtenCumsumOp::create( + rewriter, binder.getLoc(), resultType, flip, axis, none); + res = Torch::AtenFlipOp::create(rewriter, binder.getLoc(), resultType, + cumsum, dims); } else { - res = rewriter.create( - binder.getLoc(), resultType, operand, axis, none); + res = Torch::AtenCumsumOp::create(rewriter, binder.getLoc(), + resultType, operand, axis, none); } if (exclusive) - res = rewriter.create( - binder.getLoc(), resultType, res, operand, cstOne); + res = Torch::AtenSubTensorOp::create( + rewriter, binder.getLoc(), resultType, res, operand, cstOne); rewriter.replaceOp(binder.op, res); return success(); }); @@ -2078,32 +2103,33 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( return rewriter.notifyMatchFailure(binder.op, "Expected input rank to be 4"); } - Value b = rewriter.create( - binder.getLoc(), input, - rewriter.create( - binder.getLoc(), rewriter.getI64IntegerAttr(0))); - Value c = rewriter.create( - binder.getLoc(), input, - rewriter.create( - binder.getLoc(), 
rewriter.getI64IntegerAttr(1))); - Value h = rewriter.create( - binder.getLoc(), input, - rewriter.create( - binder.getLoc(), rewriter.getI64IntegerAttr(2))); - Value w = rewriter.create( - binder.getLoc(), input, - rewriter.create( - binder.getLoc(), rewriter.getI64IntegerAttr(3))); - Value cstBlockSize = rewriter.create( - binder.getLoc(), rewriter.getI64IntegerAttr(blockSize)); - Value cstBlockSizeSquare = rewriter.create( - binder.getLoc(), rewriter.getI64IntegerAttr(blockSize * blockSize)); - Value cDivBlockSizeSquare = rewriter.create( - binder.getLoc(), c, cstBlockSizeSquare); - cDivBlockSizeSquare = rewriter.create( - binder.getLoc(), cDivBlockSizeSquare); - Value reshapeSizesList = rewriter.create( - binder.getLoc(), + Value b = Torch::AtenSizeIntOp::create( + rewriter, binder.getLoc(), input, + Torch::ConstantIntOp::create(rewriter, binder.getLoc(), + rewriter.getI64IntegerAttr(0))); + Value c = Torch::AtenSizeIntOp::create( + rewriter, binder.getLoc(), input, + Torch::ConstantIntOp::create(rewriter, binder.getLoc(), + rewriter.getI64IntegerAttr(1))); + Value h = Torch::AtenSizeIntOp::create( + rewriter, binder.getLoc(), input, + Torch::ConstantIntOp::create(rewriter, binder.getLoc(), + rewriter.getI64IntegerAttr(2))); + Value w = Torch::AtenSizeIntOp::create( + rewriter, binder.getLoc(), input, + Torch::ConstantIntOp::create(rewriter, binder.getLoc(), + rewriter.getI64IntegerAttr(3))); + Value cstBlockSize = Torch::ConstantIntOp::create( + rewriter, binder.getLoc(), rewriter.getI64IntegerAttr(blockSize)); + Value cstBlockSizeSquare = Torch::ConstantIntOp::create( + rewriter, binder.getLoc(), + rewriter.getI64IntegerAttr(blockSize * blockSize)); + Value cDivBlockSizeSquare = Torch::AtenDivIntOp::create( + rewriter, binder.getLoc(), c, cstBlockSizeSquare); + cDivBlockSizeSquare = Torch::AtenIntFloatOp::create( + rewriter, binder.getLoc(), cDivBlockSizeSquare); + Value reshapeSizesList = Torch::PrimListConstructOp::create( + rewriter, binder.getLoc(), 
Torch::ListType::get(Torch::IntType::get(input.getContext())), llvm::SmallVector{b, cstBlockSize, cstBlockSize, cDivBlockSizeSquare, h, w}); @@ -2114,8 +2140,8 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( SmallVector reshapeSizesInt{ inputSizes[0], blockSize, blockSize, cDivBlockSizeSquareInt, inputSizes[2], inputSizes[3]}; - Value reshapedInput = rewriter.create( - binder.getLoc(), + Value reshapedInput = Torch::AtenReshapeOp::create( + rewriter, binder.getLoc(), inputTy.getWithSizesAndDtype(reshapeSizesInt, inputTy.getOptionalDtype()), input, reshapeSizesList); @@ -2151,12 +2177,12 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( return rewriter.notifyMatchFailure( binder.op, "Failed to create TorchTranspose op"); - Value hMulBlockSize = rewriter.create( - binder.getLoc(), h, cstBlockSize); - Value wMulBlockSize = rewriter.create( - binder.getLoc(), w, cstBlockSize); - reshapeSizesList = rewriter.create( - binder.getLoc(), + Value hMulBlockSize = Torch::AtenMulIntOp::create( + rewriter, binder.getLoc(), h, cstBlockSize); + Value wMulBlockSize = Torch::AtenMulIntOp::create( + rewriter, binder.getLoc(), w, cstBlockSize); + reshapeSizesList = Torch::PrimListConstructOp::create( + rewriter, binder.getLoc(), Torch::ListType::get(Torch::IntType::get(input.getContext())), llvm::SmallVector{b, cDivBlockSizeSquare, hMulBlockSize, wMulBlockSize}); @@ -2278,20 +2304,20 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( // get attributes as constant values SmallVector dilationValues, padValues, strideValues; for (auto i : dilations) - dilationValues.push_back(rewriter.create( - loc, rewriter.getI64IntegerAttr(i))); + dilationValues.push_back(Torch::ConstantIntOp::create( + rewriter, loc, rewriter.getI64IntegerAttr(i))); for (auto i : pads) - padValues.push_back(rewriter.create( - loc, rewriter.getI64IntegerAttr(i))); + padValues.push_back(Torch::ConstantIntOp::create( + rewriter, loc, rewriter.getI64IntegerAttr(i))); for (auto i : strides) - 
strideValues.push_back(rewriter.create( - loc, rewriter.getI64IntegerAttr(i))); - Value groupValue = rewriter.create( - loc, rewriter.getI64IntegerAttr(group)); - Value offsetGroupValue = rewriter.create( - loc, rewriter.getI64IntegerAttr(offsetGroup)); - Value useMaskValue = rewriter.create( - loc, rewriter.getBoolAttr(useMask)); + strideValues.push_back(Torch::ConstantIntOp::create( + rewriter, loc, rewriter.getI64IntegerAttr(i))); + Value groupValue = Torch::ConstantIntOp::create( + rewriter, loc, rewriter.getI64IntegerAttr(group)); + Value offsetGroupValue = Torch::ConstantIntOp::create( + rewriter, loc, rewriter.getI64IntegerAttr(offsetGroup)); + Value useMaskValue = Torch::ConstantBoolOp::create( + rewriter, loc, rewriter.getBoolAttr(useMask)); rewriter.replaceOpWithNewOp( binder.op, resultType, input, weight, offset, mask, bias, strideValues[0], strideValues[1], padValues[0], padValues[1], @@ -2356,34 +2382,32 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( "for floating point input not present"); if (isPerTensorQuantization) { - scale = rewriter.create( - loc, rewriter.getType(), scale); + scale = Torch::AtenItemOp::create( + rewriter, loc, rewriter.getType(), scale); Type zeropointTy = rewriter.getType(); if (fpOperand) zeropointTy = rewriter.getType(); zeropoint = - rewriter.create(loc, zeropointTy, zeropoint); + Torch::AtenItemOp::create(rewriter, loc, zeropointTy, zeropoint); } if (!fpOperand) { Value quantize; // Case 1: Per-Tensor Quantization for non-floating point input. if (isPerTensorQuantization) { - quantize = - rewriter.create( - loc, qTensorTy, operand, scale, zeropoint); + quantize = Torch::Aten_MakePerTensorQuantizedTensorOp::create( + rewriter, loc, qTensorTy, operand, scale, zeropoint); } else { // Case 2: Per-Channel Quantization for non-floating point input. 
int64_t axis; if (binder.s64IntegerAttr(axis, "axis", 1)) return failure(); - Value cstAxis = rewriter.create( - loc, rewriter.getI64IntegerAttr(axis)); - quantize = - rewriter.create( - loc, qTensorTy, operand, scale, zeropoint, cstAxis); + Value cstAxis = Torch::ConstantIntOp::create( + rewriter, loc, rewriter.getI64IntegerAttr(axis)); + quantize = Torch::Aten_MakePerChannelQuantizedTensorOp::create( + rewriter, loc, qTensorTy, operand, scale, zeropoint, cstAxis); } rewriter.replaceOpWithNewOp( binder.op, resultType, quantize); @@ -2391,22 +2415,22 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( } // Case 3: Per-Tensor Quantization for floating point input. - Value none = rewriter.create(loc); - Value cstFalse = rewriter.create(loc, false); + Value none = Torch::ConstantNoneOp::create(rewriter, loc); + Value cstFalse = Torch::ConstantBoolOp::create(rewriter, loc, false); auto tyVal = Torch::getScalarTypeForType(resultType.getDtype()); - Value tyConst = rewriter.create( - loc, rewriter.getType(), + Value tyConst = Torch::ConstantIntOp::create( + rewriter, loc, rewriter.getType(), rewriter.getIntegerAttr(rewriter.getIntegerType(64), static_cast(tyVal))); - Value toDtype = rewriter.create( - loc, resultType, operand, tyConst, + Value toDtype = Torch::AtenToDtypeOp::create( + rewriter, loc, resultType, operand, tyConst, /*non_blocking=*/cstFalse, /*copy=*/cstFalse, /*memory_format=*/none); - Value one = rewriter.create( - loc, rewriter.getF64FloatAttr(1.0)); - Value sub = rewriter.create( - loc, resultType, toDtype, zeropoint, one); + Value one = Torch::ConstantFloatOp::create( + rewriter, loc, rewriter.getF64FloatAttr(1.0)); + Value sub = Torch::AtenSubScalarOp::create(rewriter, loc, resultType, + toDtype, zeropoint, one); rewriter.replaceOpWithNewOp( binder.op, resultType, sub, scale); return success(); @@ -2442,7 +2466,8 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( Value ratio, trainingMode; if (numOperands == 3) { - ratio = rewriter.create(loc, 
operands[1]); + ratio = + Torch::AtenFloatImplicitOp::create(rewriter, loc, operands[1]); Value trainVal = operands[2]; auto trainTensorType = dyn_cast(trainVal.getType()); @@ -2465,26 +2490,28 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( trainVal.getDefiningOp()) { auto val = cast(valueTensorLiteralOp.getValue()) .getSplatValue(); - trainingMode = rewriter.create(loc, val); + trainingMode = Torch::ConstantBoolOp::create(rewriter, loc, val); } else { Value trainingModeScalar = - rewriter.create(loc, operands[2]); - Value cstOne = rewriter.create( - loc, rewriter.getI64IntegerAttr(1)); - trainingMode = rewriter.create( - loc, trainingModeScalar, cstOne); + Torch::AtenIntImplicitOp::create(rewriter, loc, operands[2]); + Value cstOne = Torch::ConstantIntOp::create( + rewriter, loc, rewriter.getI64IntegerAttr(1)); + trainingMode = Torch::AtenEqIntOp::create( + rewriter, loc, trainingModeScalar, cstOne); } } else if (numOperands == 2) { - ratio = rewriter.create(loc, operands[1]); - trainingMode = rewriter.create(loc, false); + ratio = + Torch::AtenFloatImplicitOp::create(rewriter, loc, operands[1]); + trainingMode = Torch::ConstantBoolOp::create(rewriter, loc, false); } else { - ratio = rewriter.create( - loc, rewriter.getF64FloatAttr(0.5)); - trainingMode = rewriter.create(loc, false); + ratio = Torch::ConstantFloatOp::create(rewriter, loc, + rewriter.getF64FloatAttr(0.5)); + trainingMode = Torch::ConstantBoolOp::create(rewriter, loc, false); } - Value dropout = rewriter.create( - loc, resultType, /*input=*/operands[0], ratio, trainingMode); + Value dropout = Torch::AtenDropoutOp::create(rewriter, loc, resultType, + /*input=*/operands[0], + ratio, trainingMode); if (binder.op->getNumResults() == 1) { rewriter.replaceOp(binder.op, dropout); @@ -2493,12 +2520,13 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( Torch::ValueTensorType maskType; if (binder.tensorResultTypeAtIndex(maskType, 1)) return failure(); - Value dtype = rewriter.create( - loc, 
rewriter.getI64IntegerAttr( - (int64_t)torch_upstream::ScalarType::Bool)); - Value none = rewriter.create(loc); - Value mask = rewriter.create( - loc, maskType, operands[0], dtype, /*layout=*/none, + Value dtype = Torch::ConstantIntOp::create( + rewriter, loc, + rewriter.getI64IntegerAttr( + (int64_t)torch_upstream::ScalarType::Bool)); + Value none = Torch::ConstantNoneOp::create(rewriter, loc); + Value mask = Torch::AtenOnesLikeOp::create( + rewriter, loc, maskType, operands[0], dtype, /*layout=*/none, /*device=*/none, /*pin_memory=*/none, /*memory_format=*/none); rewriter.replaceOp(binder.op, {dropout, mask}); return success(); @@ -2519,65 +2547,66 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( // scale = ( max(0, max(input)) - min(0, min(input)) ) / 255 Value inputMax = - rewriter.create(loc, scaleType, input); + Torch::AtenMaxOp::create(rewriter, loc, scaleType, input); Value inputMin = - rewriter.create(loc, scaleType, input); - Value constantZero = rewriter.create( - loc, rewriter.getF64FloatAttr(0)); - Value constantOne = rewriter.create( - loc, rewriter.getI64IntegerAttr(1)); + Torch::AtenMinOp::create(rewriter, loc, scaleType, input); + Value constantZero = Torch::ConstantFloatOp::create( + rewriter, loc, rewriter.getF64FloatAttr(0)); + Value constantOne = Torch::ConstantIntOp::create( + rewriter, loc, rewriter.getI64IntegerAttr(1)); Value zeroTensor = createRank0Tensor(rewriter, loc, scaleType, constantZero); - Value inputMaxW0 = rewriter.create( - loc, scaleType, inputMax, zeroTensor); - Value inputMinW0 = rewriter.create( - loc, scaleType, inputMin, zeroTensor); - Value scaleTensor = rewriter.create( - loc, scaleType, inputMaxW0, inputMinW0, constantOne); + Value inputMaxW0 = Torch::AtenMaximumOp::create( + rewriter, loc, scaleType, inputMax, zeroTensor); + Value inputMinW0 = Torch::AtenMinimumOp::create( + rewriter, loc, scaleType, inputMin, zeroTensor); + Value scaleTensor = Torch::AtenSubTensorOp::create( + rewriter, loc, scaleType, 
inputMaxW0, inputMinW0, constantOne); // Note: the following is hard-coded for ui8 - Value width = rewriter.create( - loc, rewriter.getF64FloatAttr(255)); + Value width = Torch::ConstantFloatOp::create( + rewriter, loc, rewriter.getF64FloatAttr(255)); Value widthTensor = createRank0Tensor(rewriter, loc, scaleType, width); - scaleTensor = rewriter.create( - loc, scaleType, scaleTensor, widthTensor); + scaleTensor = Torch::AtenDivTensorOp::create(rewriter, loc, scaleType, + scaleTensor, widthTensor); // compute the preZeroPoint = 0 - (inputMin/scale) // compute the zeroPoint = cast ( round (clip or saturate // (preZeroPoint))) - Value preZeroPoint = rewriter.create( - loc, scaleType, inputMin, scaleTensor); - preZeroPoint = rewriter.create( - loc, scaleType, zeroTensor, preZeroPoint, constantOne); + Value preZeroPoint = Torch::AtenDivTensorOp::create( + rewriter, loc, scaleType, inputMin, scaleTensor); + preZeroPoint = Torch::AtenSubTensorOp::create( + rewriter, loc, scaleType, zeroTensor, preZeroPoint, constantOne); // saturate to interval [0, 255] - preZeroPoint = rewriter.create( - loc, scaleType, preZeroPoint, /*min=*/constantZero, /*max=*/width); + preZeroPoint = + Torch::AtenClampOp::create(rewriter, loc, scaleType, preZeroPoint, + /*min=*/constantZero, /*max=*/width); // round, then cast to uint8 preZeroPoint = - rewriter.create(loc, scaleType, preZeroPoint); + Torch::AtenRoundOp::create(rewriter, loc, scaleType, preZeroPoint); Type qTy = rewriter.getType(); auto qTensorTy = rewriter.getType( resultType.getOptionalSizes(), qTy); auto torchqTy = Torch::getScalarTypeForType(qTy); - Value tyConst = rewriter.create( - binder.getLoc(), rewriter.getType(), + Value tyConst = Torch::ConstantIntOp::create( + rewriter, binder.getLoc(), rewriter.getType(), rewriter.getIntegerAttr(rewriter.getIntegerType(64), static_cast(torchqTy))); - Value none = rewriter.create(loc); - Value cstFalse = rewriter.create(loc, false); - Value zeroPointTensor = rewriter.create( - loc, 
zeroPointType, preZeroPoint, tyConst, + Value none = Torch::ConstantNoneOp::create(rewriter, loc); + Value cstFalse = Torch::ConstantBoolOp::create(rewriter, loc, false); + Value zeroPointTensor = Torch::AtenToDtypeOp::create( + rewriter, loc, zeroPointType, preZeroPoint, tyConst, /*non_blocking=*/cstFalse, /*copy=*/cstFalse, /*memory_format=*/none); // extract scale and zeroPoint scalars to pass to // AtenQuantizePerTensorOp - zeroPoint = rewriter.create( - loc, rewriter.getType(), zeroPointTensor); - scale = rewriter.create( - loc, rewriter.getType(), scaleTensor); - Value quantizedTensor = rewriter.create( - loc, qTensorTy, input, scale, zeroPoint, tyConst); + zeroPoint = Torch::AtenItemOp::create( + rewriter, loc, rewriter.getType(), zeroPointTensor); + scale = Torch::AtenItemOp::create( + rewriter, loc, rewriter.getType(), scaleTensor); + Value quantizedTensor = Torch::AtenQuantizePerTensorOp::create( + rewriter, loc, qTensorTy, input, scale, zeroPoint, tyConst); // get uint8 tensor output - Value output = rewriter.create(loc, resultType, - quantizedTensor); + Value output = Torch::AtenIntReprOp::create(rewriter, loc, resultType, + quantizedTensor); rewriter.replaceOp(binder.op, {output, scaleTensor, zeroPointTensor}); return success(); }); @@ -2603,10 +2632,10 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( binder.f32FloatAttr(alpha, "alpha", 1.0) || binder.tensorResultType(resultType)) return failure(); - Value cstAlpha = rewriter.create( - loc, rewriter.getF64FloatAttr(alpha)); - Value cstOne = rewriter.create( - loc, rewriter.getF64FloatAttr(1.0)); + Value cstAlpha = Torch::ConstantFloatOp::create( + rewriter, loc, rewriter.getF64FloatAttr(alpha)); + Value cstOne = Torch::ConstantFloatOp::create( + rewriter, loc, rewriter.getF64FloatAttr(1.0)); rewriter.replaceOpWithNewOp( binder.op, resultType, input, cstAlpha, /*scale=*/cstOne, /*input_scale=*/cstOne); @@ -2665,8 +2694,8 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( // Variable to store 
1-D onnx shape tensor, shapeSizes[0] has the // dimension size // A constant zero value - Value zero = rewriter.create( - loc, rewriter.getI64IntegerAttr(0)); + Value zero = Torch::ConstantIntOp::create( + rewriter, loc, rewriter.getI64IntegerAttr(0)); // Variable to store pytorch int list of shape (dimension) SmallVector dimList; @@ -2675,13 +2704,13 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( // takes list of int for (int i = 0; i < shapeSizes[0]; i++) { // extract dim from shape - Value selectIndex = rewriter.create( - loc, rewriter.getType(), + Value selectIndex = Torch::ConstantIntOp::create( + rewriter, loc, rewriter.getType(), rewriter.getIntegerAttr(rewriter.getIntegerType(64), i)); - Value extract = rewriter.create( - loc, selectResultType, shape, zero, selectIndex); - Value selectDim = rewriter.create( - loc, rewriter.getType(), extract); + Value extract = Torch::AtenSelectIntOp::create( + rewriter, loc, selectResultType, shape, zero, selectIndex); + Value selectDim = Torch::AtenItemOp::create( + rewriter, loc, rewriter.getType(), extract); // compute dim to pass to broadcast op. For non-broadcastable dims, // pass -1 Value dim; @@ -2690,17 +2719,17 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( // broadcasted // 2. we will explicitly disallow broadcasting dynamic dims that are // secretly 1. - dim = rewriter.create(loc, -1); + dim = Torch::ConstantIntOp::create(rewriter, loc, -1); // Assert dataShape[i + rankDiff] >= selectDim. If both are // constant, this should fold out. 
Value iv = - rewriter.create(loc, i + rankDifference); - auto sz = rewriter.create( - loc, rewriter.getType(), data, iv); + Torch::ConstantIntOp::create(rewriter, loc, i + rankDifference); + auto sz = Torch::AtenSizeIntOp::create( + rewriter, loc, rewriter.getType(), data, iv); Value gtSelect = - rewriter.create(loc, sz, selectDim); - rewriter.create( - loc, gtSelect, + Torch::AtenGeIntOp::create(rewriter, loc, sz, selectDim); + Torch::RuntimeAssertOp::create( + rewriter, loc, gtSelect, rewriter.getStringAttr( "onnx.Expand input has a dim that is not statically 1; " "expected this dim >= dim provided shape.")); @@ -2713,8 +2742,8 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( } dimList.push_back(dim); } - Value dimValueList = rewriter.create( - binder.getLoc(), + Value dimValueList = Torch::PrimListConstructOp::create( + rewriter, binder.getLoc(), Torch::ListType::get(Torch::IntType::get(binder.op->getContext())), dimList); rewriter.replaceOpWithNewOp( @@ -2739,15 +2768,16 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( shape[i] = Torch::kUnknownSize; } - Value cst0 = rewriter.create( - binder.getLoc(), rewriter.getI64IntegerAttr(0)); - Value cst1 = rewriter.create( - binder.getLoc(), rewriter.getI64IntegerAttr(1)); - Value nVal = rewriter.create(binder.getLoc(), - operand, cst0); - Value mVal = rewriter.create(binder.getLoc(), - operand, cst1); - Value noneVal = rewriter.create(binder.getLoc()); + Value cst0 = Torch::ConstantIntOp::create( + rewriter, binder.getLoc(), rewriter.getI64IntegerAttr(0)); + Value cst1 = Torch::ConstantIntOp::create( + rewriter, binder.getLoc(), rewriter.getI64IntegerAttr(1)); + Value nVal = Torch::AtenSizeIntOp::create(rewriter, binder.getLoc(), + operand, cst0); + Value mVal = Torch::AtenSizeIntOp::create(rewriter, binder.getLoc(), + operand, cst1); + Value noneVal = + Torch::ConstantNoneOp::create(rewriter, binder.getLoc()); std::optional dtypeIntTorch = onnxDtypeIntToTorchDtypeInt(dtypeIntOnnx); if 
(!dtypeIntTorch.has_value()) { @@ -2755,8 +2785,9 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( binder.op, "unimplemented support for the given dtype conversion"); } - Value dtypeVal = rewriter.create( - binder.getLoc(), rewriter.getI64IntegerAttr(dtypeIntTorch.value())); + Value dtypeVal = Torch::ConstantIntOp::create( + rewriter, binder.getLoc(), + rewriter.getI64IntegerAttr(dtypeIntTorch.value())); // diagonalIndex = 0 populates the main diagonal // diagonalIndex > 0 populates an upper diagonal @@ -2768,23 +2799,23 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( return success(); } - Value diagVal = rewriter.create( - binder.getLoc(), + Value diagVal = Torch::ConstantIntOp::create( + rewriter, binder.getLoc(), rewriter.getI64IntegerAttr(std::abs(diagonalIndex))); Value newN, newM, dimVal, startVal; // get shapes of main diag eye op and zeros op if (diagonalIndex > 0) { newN = nVal; - newM = rewriter.create(binder.getLoc(), mVal, - diagVal); + newM = Torch::AtenSubIntOp::create(rewriter, binder.getLoc(), mVal, + diagVal); if (shape[1] != Torch::kUnknownSize) { shape[1] -= diagonalIndex; } dimVal = cst1; startVal = mVal; } else { - newN = rewriter.create(binder.getLoc(), nVal, - diagVal); + newN = Torch::AtenSubIntOp::create(rewriter, binder.getLoc(), nVal, + diagVal); newM = mVal; if (shape[0] != Torch::kUnknownSize) { shape[0] += diagonalIndex; @@ -2796,19 +2827,19 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( // create main diag eye op auto eyeResultType = rewriter.getType( shape, resultType.getOptionalDtype()); - Value eyeOp = rewriter.create( - binder.getLoc(), eyeResultType, newN, newM, dtypeVal, noneVal, - noneVal, noneVal); + Value eyeOp = Torch::AtenEyeMOp::create( + rewriter, binder.getLoc(), eyeResultType, newN, newM, dtypeVal, + noneVal, noneVal, noneVal); // create zeros op SmallVector zerosShapeValues = {nVal, mVal}; - Value zerosShapeList = rewriter.create( - binder.getLoc(), + Value zerosShapeList = 
Torch::PrimListConstructOp::create( + rewriter, binder.getLoc(), rewriter.getType( rewriter.getType()), zerosShapeValues); - Value zerosOp = rewriter.create( - binder.getLoc(), resultType, zerosShapeList, dtypeVal, noneVal, - noneVal, noneVal); + Value zerosOp = Torch::AtenZerosOp::create( + rewriter, binder.getLoc(), resultType, zerosShapeList, dtypeVal, + noneVal, noneVal, noneVal); // embeds the values of the eye matrix into zeros rewriter.replaceOpWithNewOp( @@ -2868,23 +2899,24 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( // If the right range is empty, add a dim of size 1 to the // right side of the shape: // cr = torch.unsqueeze(x, x.ndim) - Value rankConst = rewriter.create( - binder.getLoc(), rewriter.getI64IntegerAttr(rank)); - collapsedRight = rewriter.create( - binder.getLoc(), baseType, operand, rankConst); + Value rankConst = Torch::ConstantIntOp::create( + rewriter, binder.getLoc(), rewriter.getI64IntegerAttr(rank)); + collapsedRight = Torch::AtenUnsqueezeOp::create( + rewriter, binder.getLoc(), baseType, operand, rankConst); } else { // Otherwise, collapse the right range into a single dimension: // cr = torch._prims.collapse(x, axis, x.ndim - 1) - Value axisConst = rewriter.create( - binder.getLoc(), rewriter.getI64IntegerAttr(axis)); - Value rankLess1Const = rewriter.create( - binder.getLoc(), rewriter.getI64IntegerAttr(rank - 1)); - collapsedRight = rewriter.create( - binder.getLoc(), baseType, operand, axisConst, rankLess1Const); + Value axisConst = Torch::ConstantIntOp::create( + rewriter, binder.getLoc(), rewriter.getI64IntegerAttr(axis)); + Value rankLess1Const = Torch::ConstantIntOp::create( + rewriter, binder.getLoc(), rewriter.getI64IntegerAttr(rank - 1)); + collapsedRight = Torch::PrimsCollapseOp::create( + rewriter, binder.getLoc(), baseType, operand, axisConst, + rankLess1Const); } - Value zero = rewriter.create( - binder.getLoc(), rewriter.getI64IntegerAttr(0)); + Value zero = Torch::ConstantIntOp::create( + rewriter, 
binder.getLoc(), rewriter.getI64IntegerAttr(0)); if (axis <= 0) { // If the left range is empty, add a dim of size 1 to the @@ -2897,8 +2929,8 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( // Otherwise, collapse the left range into a single dimension: // torch._prims.collapse(cr, 0, axis - 1) - Value axisLess1Const = rewriter.create( - binder.getLoc(), rewriter.getI64IntegerAttr(axis - 1)); + Value axisLess1Const = Torch::ConstantIntOp::create( + rewriter, binder.getLoc(), rewriter.getI64IntegerAttr(axis - 1)); rewriter.replaceOpWithNewOp( binder.op, resultType, collapsedRight, zero, axisLess1Const); return success(); @@ -2930,26 +2962,29 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( cast(shape.getType()); Type selectResultType = rewriter.getType( ArrayRef({}), shapeType.getOptionalDtype()); - Value zero = rewriter.create( - binder.getLoc(), rewriter.getType(), + Value zero = Torch::ConstantIntOp::create( + rewriter, binder.getLoc(), rewriter.getType(), rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0)); for (int i = 0; i < shapeSizes[0]; i++) { - Value selectIndex = rewriter.create( - binder.getLoc(), rewriter.getType(), + Value selectIndex = Torch::ConstantIntOp::create( + rewriter, binder.getLoc(), rewriter.getType(), rewriter.getIntegerAttr(rewriter.getIntegerType(64), i)); - Value extract = rewriter.create( - binder.getLoc(), selectResultType, shape, zero, selectIndex); - Value dim = rewriter.create( - binder.getLoc(), rewriter.getType(), extract); + Value extract = Torch::AtenSelectIntOp::create( + rewriter, binder.getLoc(), selectResultType, shape, zero, + selectIndex); + Value dim = Torch::AtenItemOp::create( + rewriter, binder.getLoc(), rewriter.getType(), + extract); dimList.push_back(dim); } - Value dimValueList = rewriter.create( - binder.getLoc(), + Value dimValueList = Torch::PrimListConstructOp::create( + rewriter, binder.getLoc(), Torch::ListType::get(Torch::IntType::get(binder.op->getContext())), dimList); - Value noneVal 
= rewriter.create(binder.getLoc()); + Value noneVal = + Torch::ConstantNoneOp::create(rewriter, binder.getLoc()); // Get fill_value if it is present. // Assumption : resultDType and value attr type match. @@ -2960,8 +2995,8 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( Value splatvalue; // if no value attr is provided, default is 0.0 float value if (!attr) { - splatvalue = rewriter.create( - binder.getLoc(), rewriter.getF64FloatAttr(0.0)); + splatvalue = Torch::ConstantFloatOp::create( + rewriter, binder.getLoc(), rewriter.getF64FloatAttr(0.0)); } // If its a dense resource attr we need to convert to a dense type: @@ -3004,12 +3039,12 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( value = intattr.getInt(); } splatvalue = - rewriter.create(binder.getLoc(), value); + Torch::ConstantIntOp::create(rewriter, binder.getLoc(), value); } if (auto fpattr = dyn_cast_or_null(splattr)) - splatvalue = rewriter.create( - binder.getLoc(), + splatvalue = Torch::ConstantFloatOp::create( + rewriter, binder.getLoc(), rewriter.getF64FloatAttr(fpattr.getValueAsDouble())); rewriter.replaceOpWithNewOp( @@ -3031,12 +3066,13 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( .getWithSizesAndDtype(/*optionalSizes=*/std::nullopt, /*optionalDtype=*/nullptr); Type listType = Torch::ListType::get(listElemType); - Value tensorList = rewriter.create( - binder.op->getLoc(), listType, tensors); - Value cstEquation = rewriter.create( - binder.getLoc(), rewriter.getType(), + Value tensorList = Torch::PrimListConstructOp::create( + rewriter, binder.op->getLoc(), listType, tensors); + Value cstEquation = Torch::ConstantStrOp::create( + rewriter, binder.getLoc(), rewriter.getType(), rewriter.getStringAttr(equation)); - Value cstNone = rewriter.create(binder.getLoc()); + Value cstNone = + Torch::ConstantNoneOp::create(rewriter, binder.getLoc()); rewriter.replaceOpWithNewOp( binder.op, resultType, cstEquation, tensorList, /*path=*/cstNone); return success(); @@ -3055,12 +3091,12 @@ 
void mlir::torch::onnx_c::populateDefaultDomainAtoF( } Location loc = binder.getLoc(); - Value a0 = rewriter.create( - loc, rewriter.getF64FloatAttr(0.42)); - Value a1 = rewriter.create( - loc, rewriter.getF64FloatAttr(-0.5)); - Value a2 = rewriter.create( - loc, rewriter.getF64FloatAttr(0.08)); + Value a0 = Torch::ConstantFloatOp::create( + rewriter, loc, rewriter.getF64FloatAttr(0.42)); + Value a1 = Torch::ConstantFloatOp::create( + rewriter, loc, rewriter.getF64FloatAttr(-0.5)); + Value a2 = Torch::ConstantFloatOp::create( + rewriter, loc, rewriter.getF64FloatAttr(0.08)); auto windowFunctionResult = windowFunctionImpl(binder, rewriter, size, a0, a1, a2, resultType, @@ -3086,12 +3122,12 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( } Location loc = binder.getLoc(); - Value a0 = rewriter.create( - loc, rewriter.getF64FloatAttr(0.5)); - Value a1 = rewriter.create( - loc, rewriter.getF64FloatAttr(-0.5)); - Value a2 = rewriter.create( - loc, rewriter.getF64FloatAttr(0.0)); + Value a0 = Torch::ConstantFloatOp::create( + rewriter, loc, rewriter.getF64FloatAttr(0.5)); + Value a1 = Torch::ConstantFloatOp::create( + rewriter, loc, rewriter.getF64FloatAttr(-0.5)); + Value a2 = Torch::ConstantFloatOp::create( + rewriter, loc, rewriter.getF64FloatAttr(0.0)); auto windowFunctionResult = windowFunctionImpl(binder, rewriter, size, a0, a1, a2, resultType, @@ -3117,12 +3153,12 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( } Location loc = binder.getLoc(); - Value a0 = rewriter.create( - loc, rewriter.getF64FloatAttr(0.543478)); - Value a1 = rewriter.create( - loc, rewriter.getF64FloatAttr(-0.456522)); - Value a2 = rewriter.create( - loc, rewriter.getF64FloatAttr(0.0)); + Value a0 = Torch::ConstantFloatOp::create( + rewriter, loc, rewriter.getF64FloatAttr(0.543478)); + Value a1 = Torch::ConstantFloatOp::create( + rewriter, loc, rewriter.getF64FloatAttr(-0.456522)); + Value a2 = Torch::ConstantFloatOp::create( + rewriter, loc, rewriter.getF64FloatAttr(0.0)); 
auto windowFunctionResult = windowFunctionImpl(binder, rewriter, size, a0, a1, a2, resultType, @@ -3147,29 +3183,32 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( binder.op, "Input Tensor / attrs / resultType bind failed"); if (!binder.tensorOperandAtIndex(dftLength, 1)) { // Convert to int and pass as n - dftLength = rewriter.create( - binder.getLoc(), rewriter.getType(), dftLength); + dftLength = Torch::AtenItemOp::create( + rewriter, binder.getLoc(), rewriter.getType(), + dftLength); } else { // Default for torch is None - dftLength = rewriter.create(binder.getLoc()); + dftLength = Torch::ConstantNoneOp::create(rewriter, binder.getLoc()); } // Default is same for onnx and torch if (!binder.tensorOperandAtIndex(axis, 2)) { // convert to int and pass to dims - axis = rewriter.create( - binder.getLoc(), rewriter.getType(), axis); + axis = Torch::AtenItemOp::create(rewriter, binder.getLoc(), + rewriter.getType(), + axis); } else { // Default in torch is -1 and onnx is -2 (since -1 is for real / img) - axis = rewriter.create( - binder.getLoc(), rewriter.getI64IntegerAttr(-2)); + axis = Torch::ConstantIntOp::create(rewriter, binder.getLoc(), + rewriter.getI64IntegerAttr(-2)); } if (onesided == 1) return rewriter.notifyMatchFailure(binder.op, "Unsupported option : onesided"); // norm default string attr - Value norm = rewriter.create( - binder.getLoc(), rewriter.getStringAttr(Twine("backward"))); + Value norm = Torch::ConstantStrOp::create( + rewriter, binder.getLoc(), + rewriter.getStringAttr(Twine("backward"))); // Convert from [....., 2] complex number repr for fft consumption. 
Torch::ValueTensorType inType = binder.toValidTensorType(inTensor.getType()); @@ -3183,41 +3222,40 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( Value inForComplexVal = inTensor; ArrayRef inForComplexSizes = inType.getSizes().drop_back(); if (lastIndex == 1) { - Value constZeroVal = rewriter.create( - binder.getLoc(), rewriter.getF64FloatAttr(0)); - Value constOne = rewriter.create( - binder.getLoc(), rewriter.getI64IntegerAttr(1)); - Value constZero = rewriter.create( - binder.getLoc(), rewriter.getI64IntegerAttr(0)); + Value constZeroVal = Torch::ConstantFloatOp::create( + rewriter, binder.getLoc(), rewriter.getF64FloatAttr(0)); + Value constOne = Torch::ConstantIntOp::create( + rewriter, binder.getLoc(), rewriter.getI64IntegerAttr(1)); + Value constZero = Torch::ConstantIntOp::create( + rewriter, binder.getLoc(), rewriter.getI64IntegerAttr(0)); Value padSizeList = - rewriter - .create( - binder.getLoc(), - Torch::ListType::get(rewriter.getType()), - SmallVector({constZero, constOne})) + Torch::PrimListConstructOp::create( + rewriter, binder.getLoc(), + Torch::ListType::get(rewriter.getType()), + SmallVector({constZero, constOne})) .getResult(); - Value modeVal = rewriter.create( - binder.getLoc(), rewriter.getStringAttr("constant")); + Value modeVal = Torch::ConstantStrOp::create( + rewriter, binder.getLoc(), rewriter.getStringAttr("constant")); SmallVector resSize(inForComplexSizes); resSize.push_back(2); - inForComplexVal = rewriter.create( - binder.getLoc(), + inForComplexVal = Torch::AtenPadOp::create( + rewriter, binder.getLoc(), inType.getWithSizesAndDtype(resSize, inType.getOptionalDtype()), inTensor, padSizeList, modeVal, constZeroVal); } Type inComplexTensorType = Torch::ValueTensorType::get( binder.op->getContext(), inForComplexSizes, mlir::ComplexType::get(inType.getDtype())); - Value inComplexTensor = rewriter.create( - binder.getLoc(), inComplexTensorType, inForComplexVal); + Value inComplexTensor = Torch::AtenViewAsComplexOp::create( + 
rewriter, binder.getLoc(), inComplexTensorType, inForComplexVal); Value ftOp; if (inverse == 0) { - ftOp = rewriter.create( - binder.getLoc(), inComplexTensorType, inComplexTensor, + ftOp = Torch::AtenFftFftOp::create( + rewriter, binder.getLoc(), inComplexTensorType, inComplexTensor, /*n = */ dftLength, /*dim = */ axis, /*norm = */ norm); } else { - ftOp = rewriter.create( - binder.getLoc(), inComplexTensorType, inComplexTensor, + ftOp = Torch::AtenFftIfftOp::create( + rewriter, binder.getLoc(), inComplexTensorType, inComplexTensor, /*n = */ dftLength, /*dim = */ axis, /*norm = */ norm); } rewriter.replaceOpWithNewOp(binder.op, diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp index 1a0f7920ffcc..a131d4c6ce5c 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp @@ -43,32 +43,33 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( // HardSigmoid computes the following expression: // max(0, min(1, alpha * x + beta)) - Value constAlpha = rewriter.create( - binder.getLoc(), rewriter.getType(), + Value constAlpha = Torch::ConstantFloatOp::create( + rewriter, binder.getLoc(), rewriter.getType(), rewriter.getF64FloatAttr(alpha)); - Value constBeta = rewriter.create( - binder.getLoc(), rewriter.getType(), + Value constBeta = Torch::ConstantFloatOp::create( + rewriter, binder.getLoc(), rewriter.getType(), rewriter.getF64FloatAttr(beta)); // Expression: alpha * x + beta - Value alphaMulX = rewriter.create( - binder.getLoc(), resultType, tensorOperand, constAlpha); - Value constOne = rewriter.create( - binder.getLoc(), rewriter.getType(), + Value alphaMulX = Torch::AtenMulScalarOp::create( + rewriter, binder.getLoc(), resultType, tensorOperand, constAlpha); + Value constOne = Torch::ConstantFloatOp::create( + rewriter, binder.getLoc(), rewriter.getType(), rewriter.getF64FloatAttr(1.0)); - Value alphaMulXPlusBeta = 
rewriter.create( - binder.getLoc(), resultType, alphaMulX, constBeta, + Value alphaMulXPlusBeta = Torch::AtenAddScalarOp::create( + rewriter, binder.getLoc(), resultType, alphaMulX, constBeta, /*alpha=*/constOne); // Expression: min(1, alpha * x + beta) Value oneTensor = createRank0Tensor(rewriter, binder.getLoc(), resultType, constOne); - Value minExpression = rewriter.create( - binder.getLoc(), resultType, oneTensor, alphaMulXPlusBeta); + Value minExpression = + Torch::AtenMinimumOp::create(rewriter, binder.getLoc(), resultType, + oneTensor, alphaMulXPlusBeta); // Expression: max(0, min(1, alpha * x + beta)) - Value constZero = rewriter.create( - binder.getLoc(), rewriter.getF64FloatAttr(0.0)); + Value constZero = Torch::ConstantFloatOp::create( + rewriter, binder.getLoc(), rewriter.getF64FloatAttr(0.0)); Value zeroTensor = createRank0Tensor(rewriter, binder.getLoc(), resultType, constZero); rewriter.replaceOpWithNewOp( @@ -86,8 +87,8 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( binder.customOpNameStringAttr(approximate, "approximate", "none")) return failure(); - Value vApproximate = rewriter.create( - binder.getLoc(), rewriter.getType(), + Value vApproximate = Torch::ConstantStrOp::create( + rewriter, binder.getLoc(), rewriter.getType(), rewriter.getStringAttr(approximate)); rewriter.replaceOpWithNewOp(binder.op, resultType, @@ -151,17 +152,17 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( return rewriter.notifyMatchFailure(binder.op, "align_corners bind failure"); - Value interpolationMode = rewriter.create( - binder.getLoc(), rewriter.getType(), + Value interpolationMode = Torch::ConstantIntOp::create( + rewriter, binder.getLoc(), rewriter.getType(), rewriter.getIntegerAttr(rewriter.getIntegerType(64), iModeInt)); - Value paddingMode = rewriter.create( - binder.getLoc(), rewriter.getType(), + Value paddingMode = Torch::ConstantIntOp::create( + rewriter, binder.getLoc(), rewriter.getType(), 
rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0)); bool alignMode = align; - Value alignCorners = rewriter.create( - binder.getLoc(), rewriter.getType(), + Value alignCorners = Torch::ConstantBoolOp::create( + rewriter, binder.getLoc(), rewriter.getType(), rewriter.getBoolAttr(alignMode)); rewriter.replaceOpWithNewOp( @@ -184,11 +185,12 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( return rewriter.notifyMatchFailure( binder.op, "condition must have one single element per " "https://onnx.ai/onnx/operators/onnx__If.html"); - auto conditionInt = rewriter.create( - binder.getLoc(), rewriter.getType(), + auto conditionInt = Torch::AtenItemOp::create( + rewriter, binder.getLoc(), rewriter.getType(), conditionTensor); - auto conditionBool = rewriter.create( - binder.getLoc(), rewriter.getType(), conditionInt); + auto conditionBool = Torch::AtenBoolIntOp::create( + rewriter, binder.getLoc(), rewriter.getType(), + conditionInt); llvm::SmallVector resultTypes; if (binder.tensorResultTypes(resultTypes)) { @@ -202,8 +204,8 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( return rewriter.notifyMatchFailure(binder.op, "region bind failure"); } - auto primIfOp = rewriter.create( - binder.getLoc(), TypeRange(resultTypes), conditionBool); + auto primIfOp = Torch::PrimIfOp::create( + rewriter, binder.getLoc(), TypeRange(resultTypes), conditionBool); auto inlineIfCase = [&](Region &srcRegion, Region &dstRegion) { rewriter.inlineRegionBefore(srcRegion, dstRegion, dstRegion.begin()); @@ -230,8 +232,8 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( if (terOpRank != resRank) return failure(); if (terType != resultTypes[i]) { - Value cast = rewriter.create( - binder.getLoc(), resultTypes[i], terOperands[i]); + Value cast = Torch::TensorStaticInfoCastOp::create( + rewriter, binder.getLoc(), resultTypes[i], terOperands[i]); terOperands[i] = cast; } } @@ -309,20 +311,21 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( // MaxTripCount - tensor int64 
scalar (or empty) Value maxTripCountTensor = operands[0]; - auto maxTripCountInt = rewriter.create( - binder.getLoc(), rewriter.getType(), + auto maxTripCountInt = Torch::AtenItemOp::create( + rewriter, binder.getLoc(), rewriter.getType(), maxTripCountTensor); // Condition - tensor bool scalar (or empty) Value conditionTensor = operands[1]; - auto conditionInt = rewriter.create( - binder.getLoc(), rewriter.getType(), + auto conditionInt = Torch::AtenItemOp::create( + rewriter, binder.getLoc(), rewriter.getType(), conditionTensor); - auto conditionBool = rewriter.create( - binder.getLoc(), rewriter.getType(), conditionInt); + auto conditionBool = Torch::AtenBoolIntOp::create( + rewriter, binder.getLoc(), rewriter.getType(), + conditionInt); // To be used for "for like" loop case - auto constBoolTrue = rewriter.create( - binder.getLoc(), rewriter.getBoolAttr(true)); + auto constBoolTrue = Torch::ConstantBoolOp::create( + rewriter, binder.getLoc(), rewriter.getBoolAttr(true)); // Others (if present) - variadic (can be tensors and scalar values) if (binder.getNumOperands() > 2) { @@ -366,8 +369,8 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( loopIsForLike ? 
constBoolTrue : conditionBool.getResult(); auto loc = binder.getLoc(); mlir::ImplicitLocOpBuilder b(loc, rewriter); - auto loop = b.create( - TypeRange(operandTypeVec), maxTripCountInt, loopInitCondition, + auto loop = Torch::PrimLoopOp::create( + b, TypeRange(operandTypeVec), maxTripCountInt, loopInitCondition, ValueRange(operands)); rewriter.cloneRegionBefore(*loopBodyIn, loop.getRegion(), @@ -427,11 +430,11 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( } else { // Only use when loop is not forlike Value terminatorCondTensor = terminatorOperands[0]; - auto terminatorCondInt = rewriter.create( - binder.getLoc(), rewriter.getType(), + auto terminatorCondInt = Torch::AtenItemOp::create( + rewriter, binder.getLoc(), rewriter.getType(), terminatorCondTensor); - auto terminatorCondBool = rewriter.create( - binder.getLoc(), rewriter.getType(), + auto terminatorCondBool = Torch::AtenBoolIntOp::create( + rewriter, binder.getLoc(), rewriter.getType(), terminatorCondInt); terminatorCond = terminatorCondBool.getResult(); } @@ -453,9 +456,9 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( int64_t axis; if (binder.s64IntegerAttr(axis, "axis", -1)) return rewriter.notifyMatchFailure(binder.op, "axis bind failure"); - Value axisConst = rewriter.create( - binder.getLoc(), rewriter.getI64IntegerAttr(axis)); - Value none = rewriter.create(binder.getLoc()); + Value axisConst = Torch::ConstantIntOp::create( + rewriter, binder.getLoc(), rewriter.getI64IntegerAttr(axis)); + Value none = Torch::ConstantNoneOp::create(rewriter, binder.getLoc()); rewriter.replaceOpWithNewOp( binder.op, resultType, input, axisConst, none); return success(); @@ -486,11 +489,11 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( return rewriter.notifyMatchFailure( binder.op, "failed to get input type or sizes"); - Value axisConst = rewriter.create( - binder.getLoc(), rewriter.getI64IntegerAttr(axis)); - Value none = rewriter.create(binder.getLoc()); - Value cstEnd = rewriter.create( - 
binder.getLoc(), rewriter.getI64IntegerAttr(rank - 1)); + Value axisConst = Torch::ConstantIntOp::create( + rewriter, binder.getLoc(), rewriter.getI64IntegerAttr(axis)); + Value none = Torch::ConstantNoneOp::create(rewriter, binder.getLoc()); + Value cstEnd = Torch::ConstantIntOp::create( + rewriter, binder.getLoc(), rewriter.getI64IntegerAttr(rank - 1)); // The old version of LogSoftmax flattens post-axis dims, performs // LogSoftmax on the flattened dim, then unflattens back to the original @@ -508,8 +511,8 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( int64_t prodRightSizes = 1; llvm::SmallVector rightDimConsts; for (int64_t n : rightDims) { - rightDimConsts.push_back(rewriter.create( - binder.getLoc(), rewriter.getI64IntegerAttr(n))); + rightDimConsts.push_back(Torch::ConstantIntOp::create( + rewriter, binder.getLoc(), rewriter.getI64IntegerAttr(n))); if (n == Torch::kUnknownSize) { prodRightSizes = -1; break; @@ -518,19 +521,20 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( } leftDims.push_back(prodRightSizes); // the following list will be used to unflatten the right side - Value rightDimsPrimList = rewriter.create( - binder.getLoc(), + Value rightDimsPrimList = Torch::PrimListConstructOp::create( + rewriter, binder.getLoc(), rewriter.getType( rewriter.getType()), rightDimConsts); auto flatRightTy = rewriter.getType( leftDims, inputTy.getOptionalDtype()); // flatten input - Value inputFlatRight = rewriter.create( - binder.getLoc(), flatRightTy, input, axisConst, cstEnd); + Value inputFlatRight = Torch::AtenFlattenUsingIntsOp::create( + rewriter, binder.getLoc(), flatRightTy, input, axisConst, cstEnd); // compute lsm over flattened index - Value outputFlatRight = rewriter.create( - binder.getLoc(), flatRightTy, inputFlatRight, axisConst, none); + Value outputFlatRight = Torch::AtenLogSoftmaxIntOp::create( + rewriter, binder.getLoc(), flatRightTy, inputFlatRight, axisConst, + none); // unflatten rewriter.replaceOpWithNewOp( binder.op, 
resultType, outputFlatRight, axisConst, @@ -560,14 +564,14 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( return failure(); if (binder.tensorOperandAtIndex(lhsZp, 2)) { - lhsZp = rewriter.create( - binder.getLoc(), rewriter.getType(), + lhsZp = Torch::ConstantIntOp::create( + rewriter, binder.getLoc(), rewriter.getType(), rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0)); } if (binder.tensorOperandAtIndex(rhsZp, 3)) { - rhsZp = rewriter.create( - binder.getLoc(), rewriter.getType(), + rhsZp = Torch::ConstantIntOp::create( + rewriter, binder.getLoc(), rewriter.getType(), rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0)); } @@ -589,21 +593,21 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( mlir::IntegerType::Signed)); // Subtracting the zero_point values from lhs and rhs. - Value alpha = rewriter.create( - loc, rewriter.getI64IntegerAttr(1)); + Value alpha = Torch::ConstantIntOp::create( + rewriter, loc, rewriter.getI64IntegerAttr(1)); if (auto lhsZpTy = dyn_cast(lhsZp.getType())) - lhs = rewriter.create(loc, lhs.getType(), lhs, - lhsZp, alpha); + lhs = Torch::AtenSubTensorOp::create(rewriter, loc, lhs.getType(), + lhs, lhsZp, alpha); else - lhs = rewriter.create(loc, lhs.getType(), lhs, - lhsZp, alpha); + lhs = Torch::AtenSubScalarOp::create(rewriter, loc, lhs.getType(), + lhs, lhsZp, alpha); if (auto rhsZpTy = dyn_cast(rhsZp.getType())) - rhs = rewriter.create(loc, rhs.getType(), rhs, - rhsZp, alpha); + rhs = Torch::AtenSubTensorOp::create(rewriter, loc, rhs.getType(), + rhs, rhsZp, alpha); else - rhs = rewriter.create(loc, rhs.getType(), rhs, - rhsZp, alpha); + rhs = Torch::AtenSubScalarOp::create(rewriter, loc, rhs.getType(), + rhs, rhsZp, alpha); rewriter.replaceOpWithNewOp(binder.op, resultType, lhs, rhs); @@ -698,17 +702,17 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( Type i1Ty = rewriter.getI1Type(); // Value constants - Value noneConst = b.create(); + Value noneConst = Torch::ConstantNoneOp::create(b); Value 
zeroConst = - b.create(rewriter.getI64IntegerAttr(0)); + Torch::ConstantIntOp::create(b, rewriter.getI64IntegerAttr(0)); Value oneConst = - b.create(rewriter.getI64IntegerAttr(1)); + Torch::ConstantIntOp::create(b, rewriter.getI64IntegerAttr(1)); Value twoConst = - b.create(rewriter.getI64IntegerAttr(2)); + Torch::ConstantIntOp::create(b, rewriter.getI64IntegerAttr(2)); Value int32DTypeConst = - b.create(rewriter.getI64IntegerAttr(3)); + Torch::ConstantIntOp::create(b, rewriter.getI64IntegerAttr(3)); Value float32DTypeConst = - b.create(rewriter.getI64IntegerAttr(6)); + Torch::ConstantIntOp::create(b, rewriter.getI64IntegerAttr(6)); Torch::ValueTensorType dftLenType = Torch::ValueTensorType::get(ctx, unranked, inpIntDType); @@ -717,10 +721,11 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( Type freqBinsFltType = Torch::ValueTensorType::get(ctx, shapeNMB, f32Ty); - Value dftLengthDivTwoTensor = b.create( - dftLenType, operands[1], twoConst); - Value numSpectrogramBinsTensor = b.create( - dftLenType, dftLengthDivTwoTensor, oneConst, /*alpha =*/oneConst); + Value dftLengthDivTwoTensor = Torch::AtenFloorDivideScalarOp::create( + b, dftLenType, operands[1], twoConst); + Value numSpectrogramBinsTensor = + Torch::AtenAddScalarOp::create(b, dftLenType, dftLengthDivTwoTensor, + oneConst, /*alpha =*/oneConst); Value numSpectrogramBinsItem = getItemOp( binder, rewriter, numSpectrogramBinsTensor); @@ -728,201 +733,204 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( // https://github.com/onnx/onnx/blob/main/onnx/reference/ops/op_mel_weight_matrix.py#L25-L32 // convert input Freq Hz to Mel Value twoFiveNineFiveConst = - b.create(rewriter.getF64FloatAttr(2595)); + Torch::ConstantFloatOp::create(b, rewriter.getF64FloatAttr(2595)); Value sevenHConst = - b.create(rewriter.getF64FloatAttr(700)); + Torch::ConstantFloatOp::create(b, rewriter.getF64FloatAttr(700)); Value tenConst = - b.create(rewriter.getF64FloatAttr(10)); + Torch::ConstantFloatOp::create(b, 
rewriter.getF64FloatAttr(10)); Value oneFltConst = - b.create(rewriter.getF64FloatAttr(1)); - Value LnToLog10Const = b.create( - rewriter.getF64FloatAttr(M_LOG10E)); + Torch::ConstantFloatOp::create(b, rewriter.getF64FloatAttr(1)); + Value LnToLog10Const = Torch::ConstantFloatOp::create( + b, rewriter.getF64FloatAttr(M_LOG10E)); Value lfDiv7Hfloat = - b.create(lowerEdgeHzItem, sevenHConst); + Torch::AtenDivFloatOp::create(b, lowerEdgeHzItem, sevenHConst); Type freqType = Torch::ValueTensorType::get(ctx, unranked, inpFpDType); Value lfDiv7H = - b.create(freqType, lfDiv7Hfloat); - Value lfDiv7HAdd1 = b.create( - freqType, lfDiv7H, oneConst, /*alpha =*/oneConst); - Value lfDiv7HAdd1Ln = b.create(freqType, lfDiv7HAdd1); - Value lfDiv7HAdd1Log10 = b.create( - freqType, lfDiv7HAdd1Ln, LnToLog10Const); + Torch::PrimNumToTensorScalarOp::create(b, freqType, lfDiv7Hfloat); + Value lfDiv7HAdd1 = Torch::AtenAddScalarOp::create( + b, freqType, lfDiv7H, oneConst, /*alpha =*/oneConst); + Value lfDiv7HAdd1Ln = + Torch::AtenLogOp::create(b, freqType, lfDiv7HAdd1); + Value lfDiv7HAdd1Log10 = Torch::AtenMulScalarOp::create( + b, freqType, lfDiv7HAdd1Ln, LnToLog10Const); - Value lfMel = b.create( - freqType, lfDiv7HAdd1Log10, twoFiveNineFiveConst); + Value lfMel = Torch::AtenMulScalarOp::create( + b, freqType, lfDiv7HAdd1Log10, twoFiveNineFiveConst); Value hfDiv7Hfloat = - b.create(upperEdgeHzItem, sevenHConst); + Torch::AtenDivFloatOp::create(b, upperEdgeHzItem, sevenHConst); Value hfDiv7H = - b.create(freqType, hfDiv7Hfloat); - Value hfDiv7HAdd1 = b.create( - freqType, hfDiv7H, oneConst, /*alpha =*/oneConst); - Value hfDiv7HAdd1Ln = b.create(freqType, hfDiv7HAdd1); - Value hfDiv7HAdd1Log10 = b.create( - freqType, hfDiv7HAdd1Ln, LnToLog10Const); - - Value hfMel = b.create( - freqType, hfDiv7HAdd1Log10, twoFiveNineFiveConst); - - Value hfSubLf = b.create( - hfMel.getType(), hfMel, lfMel, /*alpha=*/oneConst); + Torch::PrimNumToTensorScalarOp::create(b, freqType, hfDiv7Hfloat); + Value 
hfDiv7HAdd1 = Torch::AtenAddScalarOp::create( + b, freqType, hfDiv7H, oneConst, /*alpha =*/oneConst); + Value hfDiv7HAdd1Ln = + Torch::AtenLogOp::create(b, freqType, hfDiv7HAdd1); + Value hfDiv7HAdd1Log10 = Torch::AtenMulScalarOp::create( + b, freqType, hfDiv7HAdd1Ln, LnToLog10Const); + + Value hfMel = Torch::AtenMulScalarOp::create( + b, freqType, hfDiv7HAdd1Log10, twoFiveNineFiveConst); + + Value hfSubLf = Torch::AtenSubTensorOp::create( + b, hfMel.getType(), hfMel, lfMel, /*alpha=*/oneConst); Value numMelBinsPlus2 = - b.create(numMelBinsItem, twoConst); - Value melStep = b.create( - hfSubLf.getType(), hfSubLf, numMelBinsPlus2); + Torch::AtenAddIntOp::create(b, numMelBinsItem, twoConst); + Value melStep = Torch::AtenDivScalarOp::create( + b, hfSubLf.getType(), hfSubLf, numMelBinsPlus2); - Value lowBinsInit = b.create( - freqBinsIntType, numMelBinsItem, /*dtype=*/int32DTypeConst, + Value lowBinsInit = Torch::AtenArangeOp::create( + b, freqBinsIntType, numMelBinsItem, /*dtype=*/int32DTypeConst, /*layout=*/noneConst, /*device=*/noneConst, /*pin_memory=*/noneConst); - Value centerBinsInit = b.create( - freqBinsIntType, numMelBinsItem, /*dtype=*/int32DTypeConst, + Value centerBinsInit = Torch::AtenArangeOp::create( + b, freqBinsIntType, numMelBinsItem, /*dtype=*/int32DTypeConst, /*layout=*/noneConst, /*device=*/noneConst, /*pin_memory=*/noneConst); - Value highBinsInit = b.create( - freqBinsIntType, numMelBinsItem, /*dtype=*/int32DTypeConst, + Value highBinsInit = Torch::AtenArangeOp::create( + b, freqBinsIntType, numMelBinsItem, /*dtype=*/int32DTypeConst, /*layout=*/noneConst, /*device=*/noneConst, /*pin_memory=*/noneConst); // Common values used in conversion - Value dftLenPlusOne = b.create( - dftLenType, operands[1], oneConst, /*alpha=*/oneConst); + Value dftLenPlusOne = Torch::AtenAddScalarOp::create( + b, dftLenType, operands[1], oneConst, /*alpha=*/oneConst); Value dftLenPlusOneItem = getItemOp(binder, rewriter, dftLenPlusOne); - Value falseConst = 
b.create(false); + Value falseConst = Torch::ConstantBoolOp::create(b, false); Torch::ValueTensorType unsqueezeBinsResType = Torch::ValueTensorType::get(ctx, shape1xNMB, si32Ty); // Low bins Mel to hz - Value lowBinsMulMelStep = b.create( - freqBinsFltType, lowBinsInit, melStep); - Value lowBinsScaled = b.create( - freqBinsFltType, lowBinsMulMelStep, lfMel, /*alpha=*/oneConst); - Value lbDiv = b.create( - freqBinsFltType, lowBinsScaled, twoFiveNineFiveConst); - Value lbClone = b.create( - freqBinsFltType, lowBinsScaled, /*memory_format=*/noneConst); - Value lbTenTensor = b.create( - freqBinsFltType, lbClone, tenConst); - Value lbPow = b.create( - freqBinsFltType, lbTenTensor, lbDiv); - Value lbPowSubOne = b.create( - freqBinsFltType, lbPow, oneConst, /*alpha=*/oneConst); - Value lowBinsHz = b.create( - freqBinsFltType, lbPowSubOne, sevenHConst); + Value lowBinsMulMelStep = Torch::AtenMulTensorOp::create( + b, freqBinsFltType, lowBinsInit, melStep); + Value lowBinsScaled = Torch::AtenAddTensorOp::create( + b, freqBinsFltType, lowBinsMulMelStep, lfMel, /*alpha=*/oneConst); + Value lbDiv = Torch::AtenDivScalarOp::create( + b, freqBinsFltType, lowBinsScaled, twoFiveNineFiveConst); + Value lbClone = Torch::AtenCloneOp::create( + b, freqBinsFltType, lowBinsScaled, /*memory_format=*/noneConst); + Value lbTenTensor = Torch::AtenFillScalarOp::create(b, freqBinsFltType, + lbClone, tenConst); + Value lbPow = Torch::AtenPowTensorTensorOp::create(b, freqBinsFltType, + lbTenTensor, lbDiv); + Value lbPowSubOne = Torch::AtenSubScalarOp::create( + b, freqBinsFltType, lbPow, oneConst, /*alpha=*/oneConst); + Value lowBinsHz = Torch::AtenMulScalarOp::create( + b, freqBinsFltType, lbPowSubOne, sevenHConst); // Normalize freqBinsHz - Value lbMulDft = b.create( - freqBinsFltType, lowBinsHz, dftLenPlusOneItem); - Value lowBinsNormalized = b.create( - freqBinsFltType, lbMulDft, sampleRateItem); + Value lbMulDft = Torch::AtenMulScalarOp::create( + b, freqBinsFltType, lowBinsHz, 
dftLenPlusOneItem); + Value lowBinsNormalized = Torch::AtenDivScalarOp::create( + b, freqBinsFltType, lbMulDft, sampleRateItem); // cast to int32 - Value lowBinsInt = b.create( - freqBinsIntType, lowBinsNormalized, /*dtype=*/int32DTypeConst, + Value lowBinsInt = Torch::AtenToDtypeOp::create( + b, freqBinsIntType, lowBinsNormalized, /*dtype=*/int32DTypeConst, /*non_blocking=*/falseConst, /*copy=*/falseConst, /*memory_format=*/noneConst); - Value lowBins = b.create( - unsqueezeBinsResType, lowBinsInt, /*dim=*/zeroConst); + Value lowBins = Torch::AtenUnsqueezeOp::create( + b, unsqueezeBinsResType, lowBinsInt, /*dim=*/zeroConst); // Center bins mel to hz - Value centerBinsInitInc = b.create( - freqBinsIntType, centerBinsInit, oneConst, /*alpha=*/oneConst); - Value centerBinsMulMelStep = b.create( - freqBinsFltType, centerBinsInitInc, melStep); - Value centerBinsScaled = b.create( - freqBinsFltType, centerBinsMulMelStep, lfMel, /*alpha=*/oneConst); - Value cbDiv = b.create( - freqBinsFltType, centerBinsScaled, twoFiveNineFiveConst); - Value cbClone = b.create( - freqBinsFltType, centerBinsScaled, /*memory_format=*/noneConst); - Value cbTenTensor = b.create( - freqBinsFltType, cbClone, tenConst); - Value cbPow = b.create( - freqBinsFltType, cbTenTensor, cbDiv); - Value cbPowSubOne = b.create( - freqBinsFltType, cbPow, oneConst, /*alpha=*/oneConst); - Value centerBinsHz = b.create( - freqBinsFltType, cbPowSubOne, sevenHConst); + Value centerBinsInitInc = Torch::AtenAddScalarOp::create( + b, freqBinsIntType, centerBinsInit, oneConst, /*alpha=*/oneConst); + Value centerBinsMulMelStep = Torch::AtenMulTensorOp::create( + b, freqBinsFltType, centerBinsInitInc, melStep); + Value centerBinsScaled = Torch::AtenAddTensorOp::create( + b, freqBinsFltType, centerBinsMulMelStep, lfMel, + /*alpha=*/oneConst); + Value cbDiv = Torch::AtenDivScalarOp::create( + b, freqBinsFltType, centerBinsScaled, twoFiveNineFiveConst); + Value cbClone = Torch::AtenCloneOp::create( + b, freqBinsFltType, 
centerBinsScaled, /*memory_format=*/noneConst); + Value cbTenTensor = Torch::AtenFillScalarOp::create(b, freqBinsFltType, + cbClone, tenConst); + Value cbPow = Torch::AtenPowTensorTensorOp::create(b, freqBinsFltType, + cbTenTensor, cbDiv); + Value cbPowSubOne = Torch::AtenSubScalarOp::create( + b, freqBinsFltType, cbPow, oneConst, /*alpha=*/oneConst); + Value centerBinsHz = Torch::AtenMulScalarOp::create( + b, freqBinsFltType, cbPowSubOne, sevenHConst); // Normalize freqBinsHz - Value cbMulDft = b.create( - freqBinsFltType, centerBinsHz, dftLenPlusOneItem); - Value centerBinsNormalized = b.create( - freqBinsFltType, cbMulDft, sampleRateItem); + Value cbMulDft = Torch::AtenMulScalarOp::create( + b, freqBinsFltType, centerBinsHz, dftLenPlusOneItem); + Value centerBinsNormalized = Torch::AtenDivScalarOp::create( + b, freqBinsFltType, cbMulDft, sampleRateItem); // cast to int32 - Value centerBinsInt = b.create( - freqBinsIntType, centerBinsNormalized, /*dtype=*/int32DTypeConst, + Value centerBinsInt = Torch::AtenToDtypeOp::create( + b, freqBinsIntType, centerBinsNormalized, /*dtype=*/int32DTypeConst, /*non_blocking=*/falseConst, /*copy=*/falseConst, /*memory_format=*/noneConst); - Value centerBins = b.create( - unsqueezeBinsResType, centerBinsInt, /*dim=*/zeroConst); + Value centerBins = Torch::AtenUnsqueezeOp::create( + b, unsqueezeBinsResType, centerBinsInt, /*dim=*/zeroConst); // High bins mel to hz - Value highBinsInitInc = b.create( - freqBinsIntType, highBinsInit, twoConst, /*alpha=*/oneConst); - Value highBinsMulMelStep = b.create( - freqBinsFltType, highBinsInitInc, melStep); - Value highBinsScaled = b.create( - freqBinsFltType, highBinsMulMelStep, lfMel, /*alpha=*/oneConst); - Value hbDiv = b.create( - freqBinsFltType, highBinsScaled, twoFiveNineFiveConst); - Value hbClone = b.create( - freqBinsFltType, highBinsScaled, /*memory_format=*/noneConst); - Value hbTenTensor = b.create( - freqBinsFltType, hbClone, tenConst); - Value hbPow = b.create( - 
freqBinsFltType, hbTenTensor, hbDiv); - Value hbPowSubOne = b.create( - freqBinsFltType, hbPow, oneConst, /*alpha=*/oneConst); - Value highBinsHz = b.create( - freqBinsFltType, hbPowSubOne, sevenHConst); + Value highBinsInitInc = Torch::AtenAddScalarOp::create( + b, freqBinsIntType, highBinsInit, twoConst, /*alpha=*/oneConst); + Value highBinsMulMelStep = Torch::AtenMulTensorOp::create( + b, freqBinsFltType, highBinsInitInc, melStep); + Value highBinsScaled = Torch::AtenAddTensorOp::create( + b, freqBinsFltType, highBinsMulMelStep, lfMel, /*alpha=*/oneConst); + Value hbDiv = Torch::AtenDivScalarOp::create( + b, freqBinsFltType, highBinsScaled, twoFiveNineFiveConst); + Value hbClone = Torch::AtenCloneOp::create( + b, freqBinsFltType, highBinsScaled, /*memory_format=*/noneConst); + Value hbTenTensor = Torch::AtenFillScalarOp::create(b, freqBinsFltType, + hbClone, tenConst); + Value hbPow = Torch::AtenPowTensorTensorOp::create(b, freqBinsFltType, + hbTenTensor, hbDiv); + Value hbPowSubOne = Torch::AtenSubScalarOp::create( + b, freqBinsFltType, hbPow, oneConst, /*alpha=*/oneConst); + Value highBinsHz = Torch::AtenMulScalarOp::create( + b, freqBinsFltType, hbPowSubOne, sevenHConst); // Normalize freqBinsHz - Value hbMulDft = b.create( - freqBinsFltType, highBinsHz, dftLenPlusOneItem); - Value highBinsNormalized = b.create( - freqBinsFltType, hbMulDft, sampleRateItem); + Value hbMulDft = Torch::AtenMulScalarOp::create( + b, freqBinsFltType, highBinsHz, dftLenPlusOneItem); + Value highBinsNormalized = Torch::AtenDivScalarOp::create( + b, freqBinsFltType, hbMulDft, sampleRateItem); // cast to int32 - Value highBinsInt = b.create( - freqBinsIntType, highBinsNormalized, /*dtype=*/int32DTypeConst, + Value highBinsInt = Torch::AtenToDtypeOp::create( + b, freqBinsIntType, highBinsNormalized, /*dtype=*/int32DTypeConst, /*non_blocking=*/falseConst, /*copy=*/falseConst, /*memory_format=*/noneConst); - Value highBins = b.create( - unsqueezeBinsResType, highBinsInt, 
/*dim=*/zeroConst); + Value highBins = Torch::AtenUnsqueezeOp::create( + b, unsqueezeBinsResType, highBinsInt, /*dim=*/zeroConst); Type iotaInitType = inputIntType.getWithSizesAndDtype(shapeNSB, si32Ty); - Value iotaInit = b.create( - iotaInitType, numSpectrogramBinsItem, + Value iotaInit = Torch::AtenArangeOp::create( + b, iotaInitType, numSpectrogramBinsItem, /*dtype=*/int32DTypeConst, /*layout=*/noneConst, /*device=*/noneConst, /*pin_memory=*/noneConst); Torch::ValueTensorType unsqueezeIotaResType = Torch::ValueTensorType::get(ctx, shapeNSBx1, si32Ty); - Value iota = b.create( - unsqueezeIotaResType, iotaInit, /*dim=*/oneConst); + Value iota = Torch::AtenUnsqueezeOp::create(b, unsqueezeIotaResType, + iotaInit, /*dim=*/oneConst); - Value lowToCenter = b.create( - unsqueezeBinsResType, centerBins, lowBins, /*alpha=*/oneConst); - Value centerToHigh = b.create( - unsqueezeBinsResType, highBins, centerBins, /*alpha=*/oneConst); + Value lowToCenter = Torch::AtenSubTensorOp::create( + b, unsqueezeBinsResType, centerBins, lowBins, /*alpha=*/oneConst); + Value centerToHigh = Torch::AtenSubTensorOp::create( + b, unsqueezeBinsResType, highBins, centerBins, /*alpha=*/oneConst); Value oneConstTensor = Torch::createRank0Tensor( rewriter, binder.getLoc(), Torch::ValueTensorType::get(ctx, std::nullopt, f32Ty), oneConst); Type scaledType = inputIntType.getWithSizesAndDtype(shape1xNMB, f32Ty); - Value upscaleInit = b.create( - unsqueezeBinsResType, oneConstTensor, lowToCenter); - Value upscale = b.create( - scaledType, upscaleInit, /*dtype=*/float32DTypeConst, + Value upscaleInit = Torch::AtenMaximumOp::create( + b, unsqueezeBinsResType, oneConstTensor, lowToCenter); + Value upscale = Torch::AtenToDtypeOp::create( + b, scaledType, upscaleInit, /*dtype=*/float32DTypeConst, /*non_blocking=*/falseConst, /*copy=*/falseConst, /*memory_format=*/noneConst); - Value downscaleInit = b.create( - unsqueezeBinsResType, oneConstTensor, centerToHigh); - Value downscale = b.create( - 
scaledType, downscaleInit, /*dtype=*/float32DTypeConst, + Value downscaleInit = Torch::AtenMaximumOp::create( + b, unsqueezeBinsResType, oneConstTensor, centerToHigh); + Value downscale = Torch::AtenToDtypeOp::create( + b, scaledType, downscaleInit, /*dtype=*/float32DTypeConst, /*non_blocking=*/falseConst, /*copy=*/falseConst, /*memory_format=*/noneConst); @@ -931,23 +939,23 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( Torch::ValueTensorType diffFloatType = Torch::ValueTensorType::get(ctx, shapeNSBxNMB, f32Ty); - Value iotaSubLBInt = b.create( - binsDiffType, iota, lowBins, /*alpha=*/oneConst); - Value iotaSubLB = b.create( - diffFloatType, iotaSubLBInt, /*dtype=*/float32DTypeConst, + Value iotaSubLBInt = Torch::AtenSubTensorOp::create( + b, binsDiffType, iota, lowBins, /*alpha=*/oneConst); + Value iotaSubLB = Torch::AtenToDtypeOp::create( + b, diffFloatType, iotaSubLBInt, /*dtype=*/float32DTypeConst, /*non_blocking=*/falseConst, /*copy=*/falseConst, /*memory_format=*/noneConst); - Value rampUp = - b.create(diffFloatType, iotaSubLB, upscale); + Value rampUp = Torch::AtenDivTensorOp::create(b, diffFloatType, + iotaSubLB, upscale); - Value hbSubIotaInt = b.create( - binsDiffType, highBins, iota, /*alpha=*/oneConst); - Value hbSubIota = b.create( - diffFloatType, hbSubIotaInt, /*dtype=*/float32DTypeConst, + Value hbSubIotaInt = Torch::AtenSubTensorOp::create( + b, binsDiffType, highBins, iota, /*alpha=*/oneConst); + Value hbSubIota = Torch::AtenToDtypeOp::create( + b, diffFloatType, hbSubIotaInt, /*dtype=*/float32DTypeConst, /*non_blocking=*/falseConst, /*copy=*/falseConst, /*memory_format=*/noneConst); - Value rampDown = b.create(diffFloatType, - hbSubIota, downscale); + Value rampDown = Torch::AtenDivTensorOp::create(b, diffFloatType, + hbSubIota, downscale); // ramp values Type iotaCmpBinsType = @@ -955,38 +963,38 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( // Iota Cmp Bins Value iotaGtEqCBins = - b.create(iotaCmpBinsType, iota, centerBins); 
+ Torch::AtenGeTensorOp::create(b, iotaCmpBinsType, iota, centerBins); Value iotaEqCBins = - b.create(iotaCmpBinsType, iota, centerBins); + Torch::AtenEqTensorOp::create(b, iotaCmpBinsType, iota, centerBins); Value iotaLtLBins = - b.create(iotaCmpBinsType, iota, lowBins); + Torch::AtenLtTensorOp::create(b, iotaCmpBinsType, iota, lowBins); Value iotaGtLBins = - b.create(iotaCmpBinsType, iota, highBins); + Torch::AtenGtTensorOp::create(b, iotaCmpBinsType, iota, highBins); // Create output freq ramps Low-Center-High Type rampInitType = inputIntType.getWithSizesAndDtype(shapeNSBxNMB, f32Ty); - Value rampInit = b.create( - rampInitType, iotaGtEqCBins, rampDown, rampUp); - Value rampInitLt = b.create( - rampInitType, iotaLtLBins, zeroConst, rampInit); - Value rampInitLtGt = b.create( - rampInitType, iotaGtLBins, zeroConst, rampInitLt); + Value rampInit = Torch::AtenWhereSelfOp::create( + b, rampInitType, iotaGtEqCBins, rampDown, rampUp); + Value rampInitLt = Torch::AtenWhereScalarSelfOp::create( + b, rampInitType, iotaLtLBins, zeroConst, rampInit); + Value rampInitLtGt = Torch::AtenWhereScalarSelfOp::create( + b, rampInitType, iotaGtLBins, zeroConst, rampInitLt); Type C2HCmpBinsType = inputIntType.getWithSizesAndDtype(shape1xNMB, i1Ty); - Value C2HEqZero = b.create( - C2HCmpBinsType, centerToHigh, zeroConst); - Value cornerCases = b.create( - iotaCmpBinsType, iotaEqCBins, C2HEqZero); - Value rampOutput = b.create( - rampInitType, cornerCases, oneFltConst, rampInitLtGt); - - Value outputDTypeConst = b.create( - rewriter.getType(), + Value C2HEqZero = Torch::AtenEqScalarOp::create( + b, C2HCmpBinsType, centerToHigh, zeroConst); + Value cornerCases = Torch::AtenLogicalAndOp::create( + b, iotaCmpBinsType, iotaEqCBins, C2HEqZero); + Value rampOutput = Torch::AtenWhereScalarSelfOp::create( + b, rampInitType, cornerCases, oneFltConst, rampInitLtGt); + + Value outputDTypeConst = Torch::ConstantIntOp::create( + b, rewriter.getType(), 
rewriter.getI64IntegerAttr(torchDTypeInt.value())); - Value finalOutput = b.create( - resultType, rampOutput, /*dtype=*/outputDTypeConst, + Value finalOutput = Torch::AtenToDtypeOp::create( + b, resultType, rampOutput, /*dtype=*/outputDTypeConst, /*non_blocking=*/falseConst, /*copy=*/falseConst, /*memory_format=*/noneConst); @@ -1027,16 +1035,17 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( "unimplemented support for the given dtype conversion"); } - Value torchDtypeIntValue = rewriter.create( - binder.getLoc(), rewriter.getI64IntegerAttr(torchDtype.value())); - Value numSamples = rewriter.create( - binder.getLoc(), rewriter.getI64IntegerAttr(sampleSize)); + Value torchDtypeIntValue = Torch::ConstantIntOp::create( + rewriter, binder.getLoc(), + rewriter.getI64IntegerAttr(torchDtype.value())); + Value numSamples = Torch::ConstantIntOp::create( + rewriter, binder.getLoc(), rewriter.getI64IntegerAttr(sampleSize)); // PRG is seeded globally by default - Value none = rewriter.create(binder.getLoc()); + Value none = Torch::ConstantNoneOp::create(rewriter, binder.getLoc()); // Sample with replacement by default (no onnx equivalent in arguments) - Value cstTrue = rewriter.create( - binder.getLoc(), rewriter.getBoolAttr(true)); + Value cstTrue = Torch::ConstantBoolOp::create( + rewriter, binder.getLoc(), rewriter.getBoolAttr(true)); // Torch Multinomial always produces a LongTensor Torch::ValueTensorType selfType = @@ -1048,12 +1057,12 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( Torch::ValueTensorType multinomialOutputType = Torch::ValueTensorType::get(selfType.getContext(), outShapes, int64Dtype); - Value multinomialTensor = rewriter.create( - binder.getLoc(), multinomialOutputType, self, numSamples, cstTrue, - none); + Value multinomialTensor = Torch::AtenMultinomialOp::create( + rewriter, binder.getLoc(), multinomialOutputType, self, numSamples, + cstTrue, none); - Value cstFalse = rewriter.create( - binder.getLoc(), rewriter.getBoolAttr(false)); + 
Value cstFalse = Torch::ConstantBoolOp::create( + rewriter, binder.getLoc(), rewriter.getBoolAttr(false)); rewriter.replaceOpWithNewOp( binder.op, resultType, multinomialTensor, torchDtypeIntValue, cstFalse, cstFalse, none); @@ -1078,22 +1087,23 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( // optional third tensor argument if (binder.tensorOperandAtIndex(weight, 2)) { - weight = rewriter.create(binder.getLoc()); + weight = Torch::ConstantNoneOp::create(rewriter, binder.getLoc()); } - ignore_index = rewriter.create( - binder.getLoc(), rewriter.getI64IntegerAttr(ignore_index_int)); + ignore_index = Torch::ConstantIntOp::create( + rewriter, binder.getLoc(), + rewriter.getI64IntegerAttr(ignore_index_int)); // convert string reduction attr to standardized integer enum value int reduction_value = torch_upstream::get_loss_reduction_enum(reduction_str); - reduction = rewriter.create( - binder.getLoc(), rewriter.getI64IntegerAttr(reduction_value)); + reduction = Torch::ConstantIntOp::create( + rewriter, binder.getLoc(), + rewriter.getI64IntegerAttr(reduction_value)); - Value nllLoss = rewriter - .create( - binder.getLoc(), resultType, resultType, self, - target, weight, reduction, ignore_index) + Value nllLoss = Torch::AtenNllLossForwardOp::create( + rewriter, binder.getLoc(), resultType, resultType, + self, target, weight, reduction, ignore_index) ->getResult(0); rewriter.replaceOp(binder.op, nllLoss); @@ -1107,18 +1117,18 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( binder.tensorResultType(resultType)) { return failure(); } - Value zero = rewriter.create( - binder.getLoc(), rewriter.getType(), + Value zero = Torch::ConstantIntOp::create( + rewriter, binder.getLoc(), rewriter.getType(), rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0)); - Value one = rewriter.create( - binder.getLoc(), rewriter.getType(), + Value one = Torch::ConstantIntOp::create( + rewriter, binder.getLoc(), rewriter.getType(), 
rewriter.getIntegerAttr(rewriter.getIntegerType(64), 1)); auto rawSize = resultType.getSizes(); SmallVector torchResultSize(rawSize.rbegin(), rawSize.rend()); auto torchResultType = rewriter.getType( torchResultSize, resultType.getDtype()); - auto nonZero = rewriter.create( - binder.getLoc(), torchResultType, operand); + auto nonZero = Torch::AtenNonzeroOp::create(rewriter, binder.getLoc(), + torchResultType, operand); // The output tensor has a shape of ((n, z)), where (n) is the // number of dimensions in the input tensor and (z) is the // number of non-zero elements2. This is different from @@ -1244,21 +1254,22 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( createConstantIntList(binder, rewriter, shuffledPadding); Value zero; if (isa(resultTypeOut.getDtype())) { - zero = rewriter.create( - binder.getLoc(), rewriter.getType(), + zero = Torch::ConstantFloatOp::create( + rewriter, binder.getLoc(), rewriter.getType(), rewriter.getF64FloatAttr( std::numeric_limits::lowest())); } else if (isa(resultTypeOut.getDtype())) { - zero = rewriter.create( - binder.getLoc(), rewriter.getI64IntegerAttr( - std::numeric_limits::lowest())); + zero = Torch::ConstantIntOp::create( + rewriter, binder.getLoc(), + rewriter.getI64IntegerAttr( + std::numeric_limits::lowest())); } auto paddedInputTy = rewriter.getType( paddedShape, operandTy.getDtype()); - operand = rewriter.create( - binder.getLoc(), paddedInputTy, operand, shuffledPaddingList, - zero); + operand = Torch::AtenConstantPadNdOp::create( + rewriter, binder.getLoc(), paddedInputTy, operand, + shuffledPaddingList, zero); padding.clear(); padding.resize(spatial, 0); } @@ -1269,7 +1280,7 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( Value dilationsList = createConstantIntList(binder, rewriter, dilations); Value cstCeilMode = - rewriter.create(binder.getLoc(), ceilMode); + Torch::ConstantBoolOp::create(rewriter, binder.getLoc(), ceilMode); if (binder.op->getNumResults() == 2) { Torch::ValueTensorType 
resultTypeIndices; @@ -1353,11 +1364,11 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( auto floatTy = roisTy.getDtype(); auto torchIntTy = rewriter.getType(); - Value spatialScaleValue = rewriter.create( - loc, rewriter.getF64FloatAttr(spatialScale)); + Value spatialScaleValue = Torch::ConstantFloatOp::create( + rewriter, loc, rewriter.getF64FloatAttr(spatialScale)); - Value boolTrue = rewriter.create( - loc, rewriter.getBoolAttr(true)); + Value boolTrue = Torch::ConstantBoolOp::create( + rewriter, loc, rewriter.getBoolAttr(true)); ArrayRef inputShape = inputTy.getSizes(); int64_t inputRank = inputShape.size(); @@ -1431,53 +1442,53 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( SmallVector constInts(6); for (int i = 0; i <= 5; i++) { - constInts[i] = rewriter.create( - loc, rewriter.getI64IntegerAttr(i)); + constInts[i] = Torch::ConstantIntOp::create( + rewriter, loc, rewriter.getI64IntegerAttr(i)); } int64_t widthDim = inputRank - 2; - Value widthDimValue = rewriter.create( - loc, rewriter.getI64IntegerAttr(widthDim)); + Value widthDimValue = Torch::ConstantIntOp::create( + rewriter, loc, rewriter.getI64IntegerAttr(widthDim)); int64_t heightDim = inputRank - 3; - Value heightDimValue = rewriter.create( - loc, rewriter.getI64IntegerAttr(heightDim)); + Value heightDimValue = Torch::ConstantIntOp::create( + rewriter, loc, rewriter.getI64IntegerAttr(heightDim)); // extract indices of images within batch auto batchIdxsShape = SmallVector{Torch::kUnknownSize}; auto batchIdxsFloatTy = rewriter.getType(batchIdxsShape, floatTy); - Value batchIdxsFloat = rewriter.create( - loc, batchIdxsFloatTy, rois, constInts[1], constInts[0]); + Value batchIdxsFloat = Torch::AtenSelectIntOp::create( + rewriter, loc, batchIdxsFloatTy, rois, constInts[1], constInts[0]); auto batchIdxsIntTy = rewriter.getType(batchIdxsShape, intTy); - Value batchIdxs = rewriter.create( - loc, batchIdxsIntTy, batchIdxsFloat, boolTrue); + Value batchIdxs = Torch::Aten_CastLongOp::create( + 
rewriter, loc, batchIdxsIntTy, batchIdxsFloat, boolTrue); // extract scaled ranges for regions of interest auto roiBBsShape = SmallVector{Torch::kUnknownSize, 4}; auto roiBBsFloatTy = rewriter.getType(roiBBsShape, floatTy); - Value roiBBs = rewriter.create( - loc, roiBBsFloatTy, rois, constInts[1], constInts[1], constInts[5], - constInts[1]); - Value roiBBsScaledFloat = rewriter.create( - loc, roiBBsFloatTy, roiBBs, spatialScaleValue); + Value roiBBs = Torch::AtenSliceTensorOp::create( + rewriter, loc, roiBBsFloatTy, rois, constInts[1], constInts[1], + constInts[5], constInts[1]); + Value roiBBsScaledFloat = Torch::AtenMulScalarOp::create( + rewriter, loc, roiBBsFloatTy, roiBBs, spatialScaleValue); auto roiBBsTy = rewriter.getType(roiBBsShape, intTy); - Value roiBBsScaled = rewriter.create( - loc, roiBBsTy, roiBBsScaledFloat, boolTrue); + Value roiBBsScaled = Torch::Aten_CastLongOp::create( + rewriter, loc, roiBBsTy, roiBBsScaledFloat, boolTrue); SmallVector pooledRois; for (int64_t i = 0; i < numRois; i++) { - Value roiIdx = rewriter.create( - loc, rewriter.getI64IntegerAttr(i)); + Value roiIdx = Torch::ConstantIntOp::create( + rewriter, loc, rewriter.getI64IntegerAttr(i)); auto roiSpecTy = rewriter.getType( roiBBsTy.getSizes().slice(1), intTy); - Value roiSpec = rewriter.create( - loc, roiSpecTy, roiBBsScaled, constInts[0], roiIdx); + Value roiSpec = Torch::AtenSelectIntOp::create( + rewriter, loc, roiSpecTy, roiBBsScaled, constInts[0], roiIdx); // Load individual ROI specification values SmallVector roiValues(5); @@ -1486,15 +1497,16 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( SmallVector{}, intTy); Value specTensor; if (specIdx == 0) { // batch index - specTensor = rewriter.create( - loc, intEmptyTensorTy, batchIdxs, constInts[0], roiIdx); + specTensor = Torch::AtenSelectIntOp::create( + rewriter, loc, intEmptyTensorTy, batchIdxs, constInts[0], + roiIdx); } else { // roi dimension - specTensor = rewriter.create( - loc, intEmptyTensorTy, roiSpec, 
constInts[0], + specTensor = Torch::AtenSelectIntOp::create( + rewriter, loc, intEmptyTensorTy, roiSpec, constInts[0], constInts[specIdx - 1]); } - Value specValue = - rewriter.create(loc, torchIntTy, specTensor); + Value specValue = Torch::AtenItemOp::create(rewriter, loc, + torchIntTy, specTensor); roiValues[specIdx] = specValue; } Value batchIdx = roiValues[0], roiX1 = roiValues[1], @@ -1502,15 +1514,16 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( roiY2 = roiValues[4]; // add 1 to make range ends inclusive as per ONNX implementation - roiX2 = rewriter.create(loc, torchIntTy, roiX2, - constInts[1]); - roiY2 = rewriter.create(loc, torchIntTy, roiY2, - constInts[1]); + roiX2 = Torch::AtenAddOp::create(rewriter, loc, torchIntTy, roiX2, + constInts[1]); + roiY2 = Torch::AtenAddOp::create(rewriter, loc, torchIntTy, roiY2, + constInts[1]); auto imageTy = rewriter.getType( inputShape.slice(1), inputTy.getDtype()); - Value image = rewriter.create( - loc, imageTy, input, constInts[0], batchIdx); // (NC x H x W) + Value image = Torch::AtenSelectIntOp::create( + rewriter, loc, imageTy, input, constInts[0], + batchIdx); // (NC x H x W) SmallVector imageUnknownShape(imageTy.getSizes()); imageUnknownShape[heightDim] = Torch::kUnknownSize; @@ -1519,12 +1532,12 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( imageUnknownShape, imageTy.getDtype()); // extract ROI from image - Value imageExtractedY = rewriter.create( - loc, imageUnknownTy, image, heightDimValue, roiY1, roiY2, - constInts[1]); - Value region = rewriter.create( - loc, imageUnknownTy, imageExtractedY, widthDimValue, roiX1, roiX2, - constInts[1]); + Value imageExtractedY = Torch::AtenSliceTensorOp::create( + rewriter, loc, imageUnknownTy, image, heightDimValue, roiY1, + roiY2, constInts[1]); + Value region = Torch::AtenSliceTensorOp::create( + rewriter, loc, imageUnknownTy, imageExtractedY, widthDimValue, + roiX1, roiX2, constInts[1]); SmallVector pooledRegionShape(imageTy.getSizes()); 
pooledRegionShape[heightDim] = pooledShape[0]; @@ -1536,16 +1549,16 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( // apply pooling on ROI Value pooledRegion = - rewriter - .create( - loc, pooledRegionTy, pooledRegionIndicesTy, region, - outputShapeList) + Torch::AtenAdaptiveMaxPool2dOp::create( + rewriter, loc, pooledRegionTy, pooledRegionIndicesTy, region, + outputShapeList) .getResult0(); pooledRois.push_back(pooledRegion); } - Value pooledRoisList = rewriter.create( - loc, Torch::ListType::get(pooledRois[0].getType()), pooledRois); + Value pooledRoisList = Torch::PrimListConstructOp::create( + rewriter, loc, Torch::ListType::get(pooledRois[0].getType()), + pooledRois); rewriter.replaceOpWithNewOp( binder.op, resultTy, pooledRoisList, constInts[0]); @@ -1587,16 +1600,16 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( binder.f32FloatAttr(eps, "epsilon", 1e-05f)) { return failure(); } - Value none = rewriter.create(binder.getLoc()); + Value none = Torch::ConstantNoneOp::create(rewriter, binder.getLoc()); Value boolTrue = - rewriter.create(binder.getLoc(), true); + Torch::ConstantBoolOp::create(rewriter, binder.getLoc(), true); Value boolFalse = - rewriter.create(binder.getLoc(), false); - auto epsValue = rewriter.create( - binder.getLoc(), rewriter.getF64FloatAttr(eps)); + Torch::ConstantBoolOp::create(rewriter, binder.getLoc(), false); + auto epsValue = Torch::ConstantFloatOp::create( + rewriter, binder.getLoc(), rewriter.getF64FloatAttr(eps)); - auto momentum = rewriter.create( - binder.getLoc(), rewriter.getF64FloatAttr(0.0f)); + auto momentum = Torch::ConstantFloatOp::create( + rewriter, binder.getLoc(), rewriter.getF64FloatAttr(0.0f)); rewriter.replaceOpWithNewOp( binder.op, resultType, /* input */ operands[0], /* weight */ operands[1], @@ -1629,9 +1642,9 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( int64_t inputRank = inputTy.getSizes().size(); Location loc = binder.getLoc(); - Value keepDim = rewriter.create(loc, true); - Value 
unBiased = rewriter.create(loc, false); - Value none = rewriter.create(loc); + Value keepDim = Torch::ConstantBoolOp::create(rewriter, loc, true); + Value unBiased = Torch::ConstantBoolOp::create(rewriter, loc, false); + Value none = Torch::ConstantNoneOp::create(rewriter, loc); ArrayRef output_shape = resultType.getSizes(); SmallVector reduced_shape(output_shape); @@ -1647,29 +1660,29 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( resultType.getContext(), reduced_shape, resultType.getDtype()); SmallVector cstAxes; for (int64_t i : axes) { - cstAxes.push_back(rewriter.create( - loc, rewriter.getI64IntegerAttr(i))); + cstAxes.push_back(Torch::ConstantIntOp::create( + rewriter, loc, rewriter.getI64IntegerAttr(i))); } - Value axes_list = rewriter.create( - loc, + Value axes_list = Torch::PrimListConstructOp::create( + rewriter, loc, Torch::ListType::get(Torch::IntType::get(binder.op->getContext())), cstAxes); - Value mean = rewriter.create( - loc, reducedOutTy, input, axes_list, keepDim, none); - Value variance = rewriter.create( - loc, reducedOutTy, input, axes_list, unBiased, keepDim); - Value cstOne = rewriter.create( - loc, rewriter.getI64IntegerAttr(1)); - Value cstEps = rewriter.create( - loc, rewriter.getF64FloatAttr(1e-9)); - variance = rewriter.create( - loc, reducedOutTy, variance, cstEps, cstOne); + Value mean = Torch::AtenMeanDimOp::create( + rewriter, loc, reducedOutTy, input, axes_list, keepDim, none); + Value variance = Torch::AtenVarDimOp::create( + rewriter, loc, reducedOutTy, input, axes_list, unBiased, keepDim); + Value cstOne = Torch::ConstantIntOp::create( + rewriter, loc, rewriter.getI64IntegerAttr(1)); + Value cstEps = Torch::ConstantFloatOp::create( + rewriter, loc, rewriter.getF64FloatAttr(1e-9)); + variance = Torch::AtenAddScalarOp::create(rewriter, loc, reducedOutTy, + variance, cstEps, cstOne); Value sqrtVar = - rewriter.create(loc, reducedOutTy, variance); - Value inputMinusMean = rewriter.create( - loc, resultType, input, mean, 
cstOne); - Value meanVarNorm = rewriter.create( - loc, resultType, inputMinusMean, sqrtVar); + Torch::AtenSqrtOp::create(rewriter, loc, reducedOutTy, variance); + Value inputMinusMean = Torch::AtenSubTensorOp::create( + rewriter, loc, resultType, input, mean, cstOne); + Value meanVarNorm = Torch::AtenDivTensorOp::create( + rewriter, loc, resultType, inputMinusMean, sqrtVar); rewriter.replaceOp(binder.op, meanVarNorm); return success(); @@ -1684,8 +1697,8 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( } Value result = operands[0]; for (uint64_t i = 1; i < operands.size(); i++) { - result = rewriter.create( - binder.getLoc(), resultType, result, operands[i]); + result = Torch::AtenMaximumOp::create( + rewriter, binder.getLoc(), resultType, result, operands[i]); } rewriter.replaceOp(binder.op, result); return success(); @@ -1700,8 +1713,8 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( } Value result = operands[0]; for (uint64_t i = 1; i < operands.size(); i++) { - result = rewriter.create( - binder.getLoc(), resultType, result, operands[i]); + result = Torch::AtenMinimumOp::create( + rewriter, binder.getLoc(), resultType, result, operands[i]); } rewriter.replaceOp(binder.op, result); return success(); @@ -1736,14 +1749,14 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( auto ty = rewriter.getType( operandTy.getSizes(), i1ty); auto torchqTy = Torch::getScalarTypeForType(i1ty); - Value tyConst = rewriter.create( - binder.getLoc(), rewriter.getType(), + Value tyConst = Torch::ConstantIntOp::create( + rewriter, binder.getLoc(), rewriter.getType(), rewriter.getIntegerAttr(rewriter.getIntegerType(64), static_cast(torchqTy))); - Value none = rewriter.create(loc); - Value cstFalse = rewriter.create(loc, false); - operand = rewriter.create( - loc, ty, operand, tyConst, + Value none = Torch::ConstantNoneOp::create(rewriter, loc); + Value cstFalse = Torch::ConstantBoolOp::create(rewriter, loc, false); + operand = Torch::AtenToDtypeOp::create( + rewriter, 
loc, ty, operand, tyConst, /*non_blocking=*/cstFalse, /*copy=*/cstFalse, /*memory_format=*/none); } @@ -1809,8 +1822,8 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( SmallVector batchDims; SmallVector dataDims; for (int64_t i = 0; i < dataRank; ++i) { - Value k = rewriter.create(binder.getLoc(), i); - Value dataDim = rewriter.create(loc, data, k); + Value k = Torch::ConstantIntOp::create(rewriter, binder.getLoc(), i); + Value dataDim = Torch::AtenSizeIntOp::create(rewriter, loc, data, k); dataDims.push_back(dataDim); if (i < batchDimCount) { batchShape.push_back(dataShape[i]); @@ -1819,22 +1832,22 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( } // step 3. Get dimension list of indices. - Value constZero = rewriter.create( - loc, rewriter.getI64IntegerAttr(0)); - Value constOne = rewriter.create( - loc, rewriter.getI64IntegerAttr(1)); + Value constZero = Torch::ConstantIntOp::create( + rewriter, loc, rewriter.getI64IntegerAttr(0)); + Value constOne = Torch::ConstantIntOp::create( + rewriter, loc, rewriter.getI64IntegerAttr(1)); SmallVector indicesDimsMinusOne; SmallVector unflattenIndicesDims; Value indicesFlattenDim = constOne; for (int64_t i = 0; i < indicesRank - 1; ++i) { - Value k = rewriter.create(binder.getLoc(), i); + Value k = Torch::ConstantIntOp::create(rewriter, binder.getLoc(), i); Value indicesDim = - rewriter.create(loc, indices, k); + Torch::AtenSizeIntOp::create(rewriter, loc, indices, k); indicesDimsMinusOne.push_back(indicesDim); if (i >= batchDimCount) { unflattenIndicesDims.push_back(indicesDim); - indicesFlattenDim = rewriter.create( - loc, indicesFlattenDim, indicesDim); + indicesFlattenDim = Torch::AtenMulIntOp::create( + rewriter, loc, indicesFlattenDim, indicesDim); } } ArrayRef indicesShapeMinusOne = indicesShape.drop_back(); @@ -1860,8 +1873,8 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( // step 4. Convert indices_shape[-1] dimensional indexing to 1D // indexing. 
- Value sliceDim = rewriter.create( - loc, rewriter.getI64IntegerAttr(indicesRank - 1)); + Value sliceDim = Torch::ConstantIntOp::create( + rewriter, loc, rewriter.getI64IntegerAttr(indicesRank - 1)); SmallVector indicesSliceShape(indicesShapeMinusOne); indicesSliceShape.push_back(1); auto indicesSliceTy = rewriter.getType( @@ -1870,28 +1883,29 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( Value start = constZero; Value updatedIndices; for (int64_t i = 0; i < indicesLastDim; ++i) { - Value end = rewriter.create( - loc, rewriter.getI64IntegerAttr(i + 1)); - Value indicesSlice = rewriter.create( - loc, indicesSliceTy, indices, sliceDim, start, end, + Value end = Torch::ConstantIntOp::create( + rewriter, loc, rewriter.getI64IntegerAttr(i + 1)); + Value indicesSlice = Torch::AtenSliceTensorOp::create( + rewriter, loc, indicesSliceTy, indices, sliceDim, start, end, /*step=*/constOne); start = end; // Apply bounds checking on the indices slice. auto boolTy = rewriter.getType( indicesSliceShape, rewriter.getI1Type()); - Value lt = rewriter.create( - loc, boolTy, indicesSlice, constZero); - Value add = rewriter.create( - loc, indicesSliceTy, indicesSlice, dataDims[batchDimCount + i], + Value lt = Torch::AtenLtScalarOp::create(rewriter, loc, boolTy, + indicesSlice, constZero); + Value add = Torch::AtenAddScalarOp::create( + rewriter, loc, indicesSliceTy, indicesSlice, + dataDims[batchDimCount + i], /*alpha=*/constOne); - indicesSlice = rewriter.create( - loc, indicesSliceTy, lt, add, indicesSlice); + indicesSlice = Torch::AtenWhereSelfOp::create( + rewriter, loc, indicesSliceTy, lt, add, indicesSlice); if (i == 0) { updatedIndices = indicesSlice; continue; } - updatedIndices = rewriter.create( - loc, indicesSliceTy, indicesSlice, updatedIndices, + updatedIndices = Torch::AtenAddTensorOp::create( + rewriter, loc, indicesSliceTy, indicesSlice, updatedIndices, dataDims[batchDimCount + i]); } @@ -1942,69 +1956,71 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( 
// data by inserting unit dimensions. auto intListTy = rewriter.getType( rewriter.getType()); - Value reshapeIndicesSizeList = - rewriter.create(loc, intListTy, - reshapeIndicesDims); + Value reshapeIndicesSizeList = Torch::PrimListConstructOp::create( + rewriter, loc, intListTy, reshapeIndicesDims); auto reshapeIndicesTy = rewriter.getType( reshapeIndicesShape, indicesTy.getOptionalDtype()); - Value reshapedIndices = rewriter.create( - loc, reshapeIndicesTy, updatedIndices, reshapeIndicesSizeList); + Value reshapedIndices = + Torch::AtenViewOp::create(rewriter, loc, reshapeIndicesTy, + updatedIndices, reshapeIndicesSizeList); // step 7. Flatten `q-b-1` dimensions of the indices. auto flattenIndicesTy = rewriter.getType( flattenIndicesShape, indicesTy.getOptionalDtype()); - Value batchDimCountVal = rewriter.create( - loc, rewriter.getI64IntegerAttr(batchDimCount)); + Value batchDimCountVal = Torch::ConstantIntOp::create( + rewriter, loc, rewriter.getI64IntegerAttr(batchDimCount)); Value flattenedIndices = reshapedIndices; if (indicesRank == 1) { - flattenedIndices = rewriter.create( - loc, flattenIndicesTy, reshapedIndices, constZero); + flattenedIndices = Torch::AtenUnsqueezeOp::create( + rewriter, loc, flattenIndicesTy, reshapedIndices, constZero); } else if (indicesRank > 1) { if (batchDimCount > indicesRank - 2) { - flattenedIndices = rewriter.create( - loc, flattenIndicesTy, reshapedIndices, batchDimCountVal); + flattenedIndices = Torch::AtenUnsqueezeOp::create( + rewriter, loc, flattenIndicesTy, reshapedIndices, + batchDimCountVal); } else { - Value endDim = rewriter.create( - loc, rewriter.getI64IntegerAttr(indicesRank - 2)); - flattenedIndices = rewriter.create( - loc, flattenIndicesTy, reshapedIndices, batchDimCountVal, - endDim); + Value endDim = Torch::ConstantIntOp::create( + rewriter, loc, rewriter.getI64IntegerAttr(indicesRank - 2)); + flattenedIndices = Torch::AtenFlattenUsingIntsOp::create( + rewriter, loc, flattenIndicesTy, reshapedIndices, + 
batchDimCountVal, endDim); } } // step 8. Expand `r-b-indices_shape[-1]` dims of flattened indices. auto expandIndicesTy = rewriter.getType( expandIndicesShape, indicesTy.getOptionalDtype()); - Value expandIndicesSizeList = - rewriter.create(loc, intListTy, - expandIndicesDims); - Value constFalse = rewriter.create( - loc, rewriter.getType(), + Value expandIndicesSizeList = Torch::PrimListConstructOp::create( + rewriter, loc, intListTy, expandIndicesDims); + Value constFalse = Torch::ConstantBoolOp::create( + rewriter, loc, rewriter.getType(), rewriter.getBoolAttr(false)); - Value expandedIndices = rewriter.create( - loc, expandIndicesTy, flattenedIndices, expandIndicesSizeList, - /*implicit=*/constFalse); + Value expandedIndices = + Torch::AtenExpandOp::create(rewriter, loc, expandIndicesTy, + flattenedIndices, expandIndicesSizeList, + /*implicit=*/constFalse); // step 9. Flatten indices_shape[-1] dimensions of data. auto flattenDataTy = rewriter.getType( flattenDataShape, dataTy.getOptionalDtype()); - Value endDim = rewriter.create( - loc, + Value endDim = Torch::ConstantIntOp::create( + rewriter, loc, rewriter.getI64IntegerAttr(batchDimCount + indicesLastDim - 1)); Value flattenedData = data; if (indicesLastDim != 1) { - flattenedData = rewriter.create( - loc, flattenDataTy, data, batchDimCountVal, endDim); + flattenedData = Torch::AtenFlattenUsingIntsOp::create( + rewriter, loc, flattenDataTy, data, batchDimCountVal, endDim); } // step 10. Now we have flattenedData and expandedIndices of same rank // to perform gather operation. auto gatherTy = rewriter.getType( expandIndicesShape, dataTy.getOptionalDtype()); - Value gather = rewriter.create( - loc, gatherTy, flattenedData, batchDimCountVal, expandedIndices, - /*sparseGrad=*/constFalse); + Value gather = + Torch::AtenGatherOp::create(rewriter, loc, gatherTy, flattenedData, + batchDimCountVal, expandedIndices, + /*sparseGrad=*/constFalse); // step 11. Unflatten the collapsed indices dims of gather result. 
if (indicesRank == 1) { @@ -2019,8 +2035,8 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( return success(); } - Value unflattenSizeList = rewriter.create( - loc, intListTy, unflattenIndicesDims); + Value unflattenSizeList = Torch::PrimListConstructOp::create( + rewriter, loc, intListTy, unflattenIndicesDims); rewriter.replaceOpWithNewOp( binder.op, resultType, gather, batchDimCountVal, unflattenSizeList); return success(); @@ -2046,38 +2062,39 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( int64_t indicesRank = indicesTy.getSizes().size(); axis = axis < 0 ? axis + dataRank : axis; - Value index = rewriter.create( - loc, Torch::IntType::get(ctx), rewriter.getI64IntegerAttr(axis)); + Value index = Torch::ConstantIntOp::create( + rewriter, loc, Torch::IntType::get(ctx), + rewriter.getI64IntegerAttr(axis)); // Apply bounds checking on the input: auto intTy = rewriter.getType(); auto boolTy = rewriter.getType( indicesTy.getSizes(), rewriter.getI1Type()); - Value zero = rewriter.create( - loc, intTy, rewriter.getI64IntegerAttr(0)); - Value one = rewriter.create( - loc, intTy, rewriter.getI64IntegerAttr(1)); + Value zero = Torch::ConstantIntOp::create( + rewriter, loc, intTy, rewriter.getI64IntegerAttr(0)); + Value one = Torch::ConstantIntOp::create(rewriter, loc, intTy, + rewriter.getI64IntegerAttr(1)); Value lt = - rewriter.create(loc, boolTy, indices, zero); + Torch::AtenLtScalarOp::create(rewriter, loc, boolTy, indices, zero); Value dim = - rewriter.create(loc, intTy, data, index); - Value add = rewriter.create(loc, indicesTy, - indices, dim, one); - indices = rewriter.create(loc, indicesTy, lt, - add, indices); + Torch::AtenSizeIntOp::create(rewriter, loc, intTy, data, index); + Value add = Torch::AtenAddScalarOp::create(rewriter, loc, indicesTy, + indices, dim, one); + indices = Torch::AtenWhereSelfOp::create(rewriter, loc, indicesTy, lt, + add, indices); auto intListTy = rewriter.getType( rewriter.getType()); llvm::SmallVector indicesDims; for 
(int i = 0, s = indicesTy.getSizes().size(); i < s; ++i) { - Value k = rewriter.create(binder.getLoc(), i); - indicesDims.push_back(rewriter.create( - binder.getLoc(), indices, k)); + Value k = Torch::ConstantIntOp::create(rewriter, binder.getLoc(), i); + indicesDims.push_back(Torch::AtenSizeIntOp::create( + rewriter, binder.getLoc(), indices, k)); } - Value indicesSizeList = rewriter.create( - binder.getLoc(), intListTy, indicesDims); + Value indicesSizeList = Torch::PrimListConstructOp::create( + rewriter, binder.getLoc(), intListTy, indicesDims); // Determine the collapsed dim size: auto indicesCt = 1; @@ -2093,21 +2110,21 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( SmallVector{indicesCt}, indicesTy.getOptionalDtype()); if (indicesRank == 0) { - indices = rewriter.create( - binder.getLoc(), flattenTy, indices, zero); + indices = Torch::AtenUnsqueezeOp::create(rewriter, binder.getLoc(), + flattenTy, indices, zero); } else if (indicesRank > 1) { - Value rank = rewriter.create(loc, intTy, indices); - Value end = rewriter.create(loc, rank, one); - indices = rewriter.create( - loc, flattenTy, indices, zero, end); + Value rank = Torch::AtenDimOp::create(rewriter, loc, intTy, indices); + Value end = Torch::AtenSubIntOp::create(rewriter, loc, rank, one); + indices = Torch::AtenFlattenUsingIntsOp::create( + rewriter, loc, flattenTy, indices, zero, end); } llvm::SmallVector gatherShape(dataTy.getSizes()); gatherShape[axis] = indicesCt; auto gatherTy = rewriter.getType( gatherShape, dataTy.getOptionalDtype()); - Value gather = rewriter.create( - loc, gatherTy, data, index, indices); + Value gather = Torch::AtenIndexSelectOp::create(rewriter, loc, gatherTy, + data, index, indices); if (indicesRank == 1) { rewriter.replaceOp(binder.op, gather); @@ -2137,29 +2154,29 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( binder.tensorResultType(resultType) || binder.s64IntegerAttr(axis, "axis", 0)) return failure(); - Value constAxis = rewriter.create( - 
binder.getLoc(), rewriter.getType(), + Value constAxis = Torch::ConstantIntOp::create( + rewriter, binder.getLoc(), rewriter.getType(), rewriter.getIntegerAttr(rewriter.getIntegerType(64), axis)); auto indicesTy = cast(indices.getType()); - Value constZero = rewriter.create( - binder.getLoc(), rewriter.getI64IntegerAttr(0)); - Value constOne = rewriter.create( - binder.getLoc(), rewriter.getI64IntegerAttr(1)); - Value axisSize = rewriter.create(binder.getLoc(), - data, constAxis); - Value indicesAdd = rewriter.create( - binder.getLoc(), indicesTy, indices, axisSize, constOne); + Value constZero = Torch::ConstantIntOp::create( + rewriter, binder.getLoc(), rewriter.getI64IntegerAttr(0)); + Value constOne = Torch::ConstantIntOp::create( + rewriter, binder.getLoc(), rewriter.getI64IntegerAttr(1)); + Value axisSize = Torch::AtenSizeIntOp::create(rewriter, binder.getLoc(), + data, constAxis); + Value indicesAdd = Torch::AtenAddScalarOp::create( + rewriter, binder.getLoc(), indicesTy, indices, axisSize, constOne); auto boolTy = rewriter.getType( indicesTy.getSizes(), rewriter.getI1Type()); - Value lt = rewriter.create( - binder.getLoc(), boolTy, indices, constZero); - indices = rewriter.create( - binder.getLoc(), indicesTy, lt, indicesAdd, indices); + Value lt = Torch::AtenLtScalarOp::create(rewriter, binder.getLoc(), + boolTy, indices, constZero); + indices = Torch::AtenWhereSelfOp::create( + rewriter, binder.getLoc(), indicesTy, lt, indicesAdd, indices); - Value sparseGrad = rewriter.create( - binder.getLoc(), rewriter.getType(), + Value sparseGrad = Torch::ConstantBoolOp::create( + rewriter, binder.getLoc(), rewriter.getType(), rewriter.getBoolAttr(false)); rewriter.replaceOpWithNewOp( binder.op, resultType, data, constAxis, indices, sparseGrad); @@ -2180,11 +2197,11 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( binder.tensorResultType(resultType)) return failure(); - Value zero = rewriter.create( - binder.getLoc(), rewriter.getType(), + Value zero = 
Torch::ConstantIntOp::create( + rewriter, binder.getLoc(), rewriter.getType(), rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0)); - Value one = rewriter.create( - binder.getLoc(), rewriter.getType(), + Value one = Torch::ConstantIntOp::create( + rewriter, binder.getLoc(), rewriter.getType(), rewriter.getIntegerAttr(rewriter.getIntegerType(64), 1)); auto transpose = [&](Value m) -> Value { @@ -2198,8 +2215,8 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( } auto oty = Torch::ValueTensorType::get(tty.getContext(), shape, tty.getOptionalDtype()); - return rewriter.create(binder.getLoc(), - oty, m, zero, one); + return Torch::AtenTransposeIntOp::create(rewriter, binder.getLoc(), + oty, m, zero, one); }; if (transA) { @@ -2220,8 +2237,8 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( return rewriter.notifyMatchFailure(binder.op, "Expected either 2 or 3 inputs"); - Value mm = - rewriter.create(binder.getLoc(), resultType, a, b); + Value mm = Torch::AtenMmOp::create(rewriter, binder.getLoc(), + resultType, a, b); if (alpha == 1.0 && beta == 1.0) { rewriter.replaceOpWithNewOp( binder.op, resultType, mm, c, one); @@ -2229,11 +2246,11 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( } if (alpha != 1.0 && beta != 1.0) { - Value constAlpha = rewriter.create( - binder.getLoc(), rewriter.getType(), + Value constAlpha = Torch::ConstantFloatOp::create( + rewriter, binder.getLoc(), rewriter.getType(), rewriter.getF64FloatAttr(alpha)); - mm = rewriter.create( - binder.getLoc(), resultType, mm, constAlpha); + mm = Torch::AtenMulScalarOp::create(rewriter, binder.getLoc(), + resultType, mm, constAlpha); alpha = 1.0; } @@ -2242,8 +2259,8 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( std::swap(mm, c); } - Value constBeta = rewriter.create( - binder.getLoc(), rewriter.getType(), + Value constBeta = Torch::ConstantFloatOp::create( + rewriter, binder.getLoc(), rewriter.getType(), rewriter.getF64FloatAttr(beta)); rewriter.replaceOpWithNewOp( binder.op, 
resultType, mm, c, constBeta); @@ -2272,42 +2289,44 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( ArrayRef resultShape = resultType.getSizes(); SmallVector cstKernel, cstPadding, cstStrides; - Value cstZero = rewriter.create( - binder.getLoc(), rewriter.getI64IntegerAttr(0)); - Value cstOne = rewriter.create( - binder.getLoc(), rewriter.getI64IntegerAttr(1)); + Value cstZero = Torch::ConstantIntOp::create( + rewriter, binder.getLoc(), rewriter.getI64IntegerAttr(0)); + Value cstOne = Torch::ConstantIntOp::create( + rewriter, binder.getLoc(), rewriter.getI64IntegerAttr(1)); for (unsigned i = 2; i < inputRank; i++) { if (inputShape[i] == Torch::kUnknownSize) { - Value dim = rewriter.create( - binder.getLoc(), rewriter.getI64IntegerAttr(i)); - Value inputDimSize = rewriter.create( - binder.getLoc(), operand, dim); + Value dim = Torch::ConstantIntOp::create( + rewriter, binder.getLoc(), rewriter.getI64IntegerAttr(i)); + Value inputDimSize = Torch::AtenSizeIntOp::create( + rewriter, binder.getLoc(), operand, dim); cstKernel.push_back(inputDimSize); } else { int64_t kernelSize = inputShape[i] - resultShape[i] + 1; - cstKernel.push_back(rewriter.create( - binder.getLoc(), rewriter.getI64IntegerAttr(kernelSize))); + cstKernel.push_back(Torch::ConstantIntOp::create( + rewriter, binder.getLoc(), + rewriter.getI64IntegerAttr(kernelSize))); } cstPadding.push_back(cstZero); cstStrides.push_back(cstOne); } - Value kernelSizeList = rewriter.create( - binder.getLoc(), + Value kernelSizeList = Torch::PrimListConstructOp::create( + rewriter, binder.getLoc(), Torch::ListType::get(Torch::IntType::get(binder.op->getContext())), cstKernel); - Value paddingList = rewriter.create( - binder.getLoc(), + Value paddingList = Torch::PrimListConstructOp::create( + rewriter, binder.getLoc(), Torch::ListType::get(Torch::IntType::get(binder.op->getContext())), cstPadding); - Value stridesList = rewriter.create( - binder.getLoc(), + Value stridesList = Torch::PrimListConstructOp::create( + 
rewriter, binder.getLoc(), Torch::ListType::get(Torch::IntType::get(binder.op->getContext())), cstStrides); Value cstFalse = - rewriter.create(binder.getLoc(), false); + Torch::ConstantBoolOp::create(rewriter, binder.getLoc(), false); Value cstCeilMode = cstFalse; Value cstCountIncludePad = cstFalse; - Value cstNone = rewriter.create(binder.getLoc()); + Value cstNone = + Torch::ConstantNoneOp::create(rewriter, binder.getLoc()); if (inputRank == 3) { rewriter.replaceOpWithNewOp( @@ -2350,43 +2369,44 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( binder.op, "Expected result type having sizes"); } SmallVector cstKernel, cstPadding, cstStrides, cstDilations; - Value cstZero = rewriter.create( - binder.getLoc(), rewriter.getI64IntegerAttr(0)); - Value cstOne = rewriter.create( - binder.getLoc(), rewriter.getI64IntegerAttr(1)); + Value cstZero = Torch::ConstantIntOp::create( + rewriter, binder.getLoc(), rewriter.getI64IntegerAttr(0)); + Value cstOne = Torch::ConstantIntOp::create( + rewriter, binder.getLoc(), rewriter.getI64IntegerAttr(1)); for (unsigned i = 2; i < inputRank; i++) { if (inputShape[i] == Torch::kUnknownSize) { - Value dim = rewriter.create( - binder.getLoc(), rewriter.getI64IntegerAttr(i)); - Value inputDimSize = rewriter.create( - binder.getLoc(), operand, dim); + Value dim = Torch::ConstantIntOp::create( + rewriter, binder.getLoc(), rewriter.getI64IntegerAttr(i)); + Value inputDimSize = Torch::AtenSizeIntOp::create( + rewriter, binder.getLoc(), operand, dim); cstKernel.push_back(inputDimSize); } else { - cstKernel.push_back(rewriter.create( - binder.getLoc(), rewriter.getI64IntegerAttr(inputShape[i]))); + cstKernel.push_back(Torch::ConstantIntOp::create( + rewriter, binder.getLoc(), + rewriter.getI64IntegerAttr(inputShape[i]))); } cstPadding.push_back(cstZero); cstDilations.push_back(cstOne); cstStrides.push_back(cstOne); } - Value kernelSizeList = rewriter.create( - binder.getLoc(), + Value kernelSizeList = Torch::PrimListConstructOp::create( 
+ rewriter, binder.getLoc(), Torch::ListType::get(Torch::IntType::get(binder.op->getContext())), cstKernel); - Value paddingList = rewriter.create( - binder.getLoc(), + Value paddingList = Torch::PrimListConstructOp::create( + rewriter, binder.getLoc(), Torch::ListType::get(Torch::IntType::get(binder.op->getContext())), cstPadding); - Value dilationsList = rewriter.create( - binder.getLoc(), + Value dilationsList = Torch::PrimListConstructOp::create( + rewriter, binder.getLoc(), Torch::ListType::get(Torch::IntType::get(binder.op->getContext())), cstDilations); - Value stridesList = rewriter.create( - binder.getLoc(), + Value stridesList = Torch::PrimListConstructOp::create( + rewriter, binder.getLoc(), Torch::ListType::get(Torch::IntType::get(binder.op->getContext())), cstStrides); Value cstCeilMode = - rewriter.create(binder.getLoc(), false); + Torch::ConstantBoolOp::create(rewriter, binder.getLoc(), false); if (inputRank == 3) { rewriter.replaceOpWithNewOp( @@ -2434,72 +2454,73 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( ArrayRef resultShape = resultType.getSizes(); SmallVector cstKernel, cstPadding, cstStrides; - Value cstZero = rewriter.create( - binder.getLoc(), rewriter.getI64IntegerAttr(0)); - Value cstOne = rewriter.create( - binder.getLoc(), rewriter.getI64IntegerAttr(1)); + Value cstZero = Torch::ConstantIntOp::create( + rewriter, binder.getLoc(), rewriter.getI64IntegerAttr(0)); + Value cstOne = Torch::ConstantIntOp::create( + rewriter, binder.getLoc(), rewriter.getI64IntegerAttr(1)); Value numElements = cstOne; for (unsigned i = 2; i < inputRank; i++) { if (inputShape[i] == Torch::kUnknownSize) { - Value dim = rewriter.create( - binder.getLoc(), rewriter.getI64IntegerAttr(i)); - Value inputDimSize = rewriter.create( - binder.getLoc(), operand, dim); + Value dim = Torch::ConstantIntOp::create( + rewriter, binder.getLoc(), rewriter.getI64IntegerAttr(i)); + Value inputDimSize = Torch::AtenSizeIntOp::create( + rewriter, binder.getLoc(), operand, 
dim); cstKernel.push_back(inputDimSize); } else { int64_t kernelSize = inputShape[i] - resultShape[i] + 1; - cstKernel.push_back(rewriter.create( - binder.getLoc(), rewriter.getI64IntegerAttr(kernelSize))); + cstKernel.push_back(Torch::ConstantIntOp::create( + rewriter, binder.getLoc(), + rewriter.getI64IntegerAttr(kernelSize))); } - numElements = rewriter.create( - binder.getLoc(), rewriter.getType(), + numElements = Torch::AtenMulOp::create( + rewriter, binder.getLoc(), rewriter.getType(), cstKernel.back(), numElements); cstPadding.push_back(cstZero); cstStrides.push_back(cstOne); } - Value kernelSizeList = rewriter.create( - binder.getLoc(), + Value kernelSizeList = Torch::PrimListConstructOp::create( + rewriter, binder.getLoc(), Torch::ListType::get(Torch::IntType::get(binder.op->getContext())), cstKernel); - Value paddingList = rewriter.create( - binder.getLoc(), + Value paddingList = Torch::PrimListConstructOp::create( + rewriter, binder.getLoc(), Torch::ListType::get(Torch::IntType::get(binder.op->getContext())), cstPadding); - Value stridesList = rewriter.create( - binder.getLoc(), + Value stridesList = Torch::PrimListConstructOp::create( + rewriter, binder.getLoc(), Torch::ListType::get(Torch::IntType::get(binder.op->getContext())), cstStrides); Value cstFalse = - rewriter.create(binder.getLoc(), false); + Torch::ConstantBoolOp::create(rewriter, binder.getLoc(), false); Value cstCeilMode = cstFalse; Value cstCountIncludePad = cstFalse; - Value abs = rewriter.create(binder.getLoc(), - inputTensorType, operand); - Value pv = rewriter.create( - binder.getLoc(), rewriter.getType(), + Value abs = Torch::AtenAbsOp::create(rewriter, binder.getLoc(), + inputTensorType, operand); + Value pv = Torch::ConstantIntOp::create( + rewriter, binder.getLoc(), rewriter.getType(), rewriter.getIntegerAttr(rewriter.getIntegerType(64), p)); - Value pow = rewriter.create( - binder.getLoc(), inputTensorType, abs, pv); + Value pow = Torch::AtenPowTensorScalarOp::create( + rewriter, 
binder.getLoc(), inputTensorType, abs, pv); Value avgPool; if (inputRank == 3) { - avgPool = rewriter.create( - binder.getLoc(), resultType, pow, kernelSizeList, stridesList, - paddingList, cstCeilMode, cstCountIncludePad); - avgPool = rewriter.create( - binder.getLoc(), resultType, avgPool, numElements); + avgPool = Torch::AtenAvgPool1dOp::create( + rewriter, binder.getLoc(), resultType, pow, kernelSizeList, + stridesList, paddingList, cstCeilMode, cstCountIncludePad); + avgPool = Torch::AtenMulScalarOp::create( + rewriter, binder.getLoc(), resultType, avgPool, numElements); } else if (inputRank == 4) { - avgPool = rewriter.create( - binder.getLoc(), resultType, pow, kernelSizeList, stridesList, - paddingList, cstCeilMode, cstCountIncludePad, + avgPool = Torch::AtenAvgPool2dOp::create( + rewriter, binder.getLoc(), resultType, pow, kernelSizeList, + stridesList, paddingList, cstCeilMode, cstCountIncludePad, /*divisor_override=*/cstOne); } else { // inputRank == 5 - avgPool = rewriter.create( - binder.getLoc(), resultType, pow, kernelSizeList, stridesList, - paddingList, cstCeilMode, cstCountIncludePad, + avgPool = Torch::AtenAvgPool3dOp::create( + rewriter, binder.getLoc(), resultType, pow, kernelSizeList, + stridesList, paddingList, cstCeilMode, cstCountIncludePad, /*divisor_override=*/cstOne); } - Value invP = rewriter.create( - binder.getLoc(), rewriter.getType(), + Value invP = Torch::ConstantFloatOp::create( + rewriter, binder.getLoc(), rewriter.getType(), rewriter.getF64FloatAttr(double{1.0 / p})); rewriter.replaceOpWithNewOp( binder.op, resultType, avgPool, invP); @@ -2565,57 +2586,57 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( } SmallVector cstKernel, cstPadding, cstStrides; - Value cstOne = rewriter.create( - binder.getLoc(), rewriter.getI64IntegerAttr(1)); + Value cstOne = Torch::ConstantIntOp::create( + rewriter, binder.getLoc(), rewriter.getI64IntegerAttr(1)); Value numElements = cstOne; for (int64_t i : kernel) { - 
cstKernel.push_back(rewriter.create( - binder.getLoc(), rewriter.getI64IntegerAttr(i))); - numElements = rewriter.create( - binder.getLoc(), rewriter.getType(), + cstKernel.push_back(Torch::ConstantIntOp::create( + rewriter, binder.getLoc(), rewriter.getI64IntegerAttr(i))); + numElements = Torch::AtenMulOp::create( + rewriter, binder.getLoc(), rewriter.getType(), cstKernel.back(), numElements); } - Value kernelSizeList = rewriter.create( - binder.getLoc(), + Value kernelSizeList = Torch::PrimListConstructOp::create( + rewriter, binder.getLoc(), Torch::ListType::get(Torch::IntType::get(binder.op->getContext())), cstKernel); Value paddingList = createConstantIntList(binder, rewriter, padding); Value stridesList = createConstantIntList(binder, rewriter, strides); Value cstCeilMode = - rewriter.create(binder.getLoc(), ceilMode); + Torch::ConstantBoolOp::create(rewriter, binder.getLoc(), ceilMode); // onnx lp pool doesn't have countIncludePad attribute but set it to // true so that in 1D case numElements is correctly undoes divison. For // 2D/3D case, division is avoided by divison_override. 
Value cstCountIncludePad = - rewriter.create(binder.getLoc(), true); - Value pv = rewriter.create( - binder.getLoc(), rewriter.getType(), + Torch::ConstantBoolOp::create(rewriter, binder.getLoc(), true); + Value pv = Torch::ConstantIntOp::create( + rewriter, binder.getLoc(), rewriter.getType(), rewriter.getIntegerAttr(rewriter.getIntegerType(64), p)); auto inputTensorType = cast(operand.getType()); - Value abs = rewriter.create(binder.getLoc(), - inputTensorType, operand); - Value pow = rewriter.create( - binder.getLoc(), inputTensorType, abs, pv); + Value abs = Torch::AtenAbsOp::create(rewriter, binder.getLoc(), + inputTensorType, operand); + Value pow = Torch::AtenPowTensorScalarOp::create( + rewriter, binder.getLoc(), inputTensorType, abs, pv); Value avgPool; if (rank == 3) { - avgPool = rewriter.create( - binder.getLoc(), resultType, pow, kernelSizeList, stridesList, - paddingList, cstCeilMode, cstCountIncludePad); - avgPool = rewriter.create( - binder.getLoc(), resultType, avgPool, numElements); + avgPool = Torch::AtenAvgPool1dOp::create( + rewriter, binder.getLoc(), resultType, pow, kernelSizeList, + stridesList, paddingList, cstCeilMode, cstCountIncludePad); + avgPool = Torch::AtenMulScalarOp::create( + rewriter, binder.getLoc(), resultType, avgPool, numElements); } else if (rank == 4) { - avgPool = rewriter.create( - binder.getLoc(), resultType, pow, kernelSizeList, stridesList, - paddingList, cstCeilMode, cstCountIncludePad, + avgPool = Torch::AtenAvgPool2dOp::create( + rewriter, binder.getLoc(), resultType, pow, kernelSizeList, + stridesList, paddingList, cstCeilMode, cstCountIncludePad, /*divisor_override=*/cstOne); } else { // rank == 5 - avgPool = rewriter.create( - binder.getLoc(), resultType, pow, kernelSizeList, stridesList, - paddingList, cstCeilMode, cstCountIncludePad, + avgPool = Torch::AtenAvgPool3dOp::create( + rewriter, binder.getLoc(), resultType, pow, kernelSizeList, + stridesList, paddingList, cstCeilMode, cstCountIncludePad, 
/*divisor_override=*/cstOne); } - Value invP = rewriter.create( - binder.getLoc(), rewriter.getType(), + Value invP = Torch::ConstantFloatOp::create( + rewriter, binder.getLoc(), rewriter.getType(), rewriter.getF64FloatAttr(double{1.0 / p})); rewriter.replaceOpWithNewOp( binder.op, resultType, avgPool, invP); @@ -2652,22 +2673,22 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( // Convert dtype if stash_type is different from input dtype auto xType = cast(x.getType()); Value cstFalse = - rewriter.create(binder.getLoc(), false); - Value none = rewriter.create(binder.getLoc()); + Torch::ConstantBoolOp::create(rewriter, binder.getLoc(), false); + Value none = Torch::ConstantNoneOp::create(rewriter, binder.getLoc()); if (*stashDtype != xType.getOptionalDtype()) { auto newXType = xType.getWithSizesAndDtype(xType.getOptionalSizes(), *stashDtype); - Value dtypeValue = rewriter.create( - binder.getLoc(), + Value dtypeValue = Torch::ConstantIntOp::create( + rewriter, binder.getLoc(), rewriter.getI64IntegerAttr(stashTypeIntTorch.value())); - x = rewriter.create( - binder.getLoc(), newXType, x, /*dtype=*/dtypeValue, + x = Torch::AtenToDtypeOp::create( + rewriter, binder.getLoc(), newXType, x, /*dtype=*/dtypeValue, /*non_blocking=*/cstFalse, /*copy=*/cstFalse, /*memory_format=*/none); } - Value constEpsilon = rewriter.create( - binder.getLoc(), rewriter.getType(), + Value constEpsilon = Torch::ConstantFloatOp::create( + rewriter, binder.getLoc(), rewriter.getType(), rewriter.getF64FloatAttr(epsilon)); unsigned rank = 1; if (std::optional maybeRank = Torch::getTensorRank(x)) @@ -2680,11 +2701,12 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( } ArrayRef xShape = xType.getSizes(); for (int64_t n = axis; n < rank; n++) { - normalized.push_back(rewriter.create( - binder.getLoc(), rewriter.getI64IntegerAttr(xShape[n]))); + normalized.push_back(Torch::ConstantIntOp::create( + rewriter, binder.getLoc(), + rewriter.getI64IntegerAttr(xShape[n]))); } - Value 
normalized_shape = rewriter.create( - binder.getLoc(), + Value normalized_shape = Torch::PrimListConstructOp::create( + rewriter, binder.getLoc(), Torch::ListType::get(Torch::IntType::get(binder.op->getContext())), normalized); @@ -2693,8 +2715,8 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( reducedShape[i] = xShape[i]; auto reducedType = xType.getWithSizesAndDtype(reducedShape, *stashDtype); - auto y = rewriter.create( - binder.getLoc(), yType, /*meanType=*/reducedType, + auto y = Torch::AtenNativeLayerNormOp::create( + rewriter, binder.getLoc(), yType, /*meanType=*/reducedType, /*invStdDevType=*/reducedType, x, normalized_shape, scale, b, constEpsilon); @@ -2713,12 +2735,14 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( if (*stashDtype != meanType.getOptionalDtype()) { Value constDtype = Torch::getDtypeIntValueForType( rewriter, binder.getLoc(), meanType.getDtype()); - meanOutput = rewriter.create( - binder.getLoc(), meanType, meanOutput, /*dtype=*/constDtype, + meanOutput = Torch::AtenToDtypeOp::create( + rewriter, binder.getLoc(), meanType, meanOutput, + /*dtype=*/constDtype, /*non_blocking=*/cstFalse, /*copy=*/cstFalse, /*memory_format=*/none); - varOutput = rewriter.create( - binder.getLoc(), invStdDevType, varOutput, /*dtype=*/constDtype, + varOutput = Torch::AtenToDtypeOp::create( + rewriter, binder.getLoc(), invStdDevType, varOutput, + /*dtype=*/constDtype, /*non_blocking=*/cstFalse, /*copy=*/cstFalse, /*memory_format=*/none); } @@ -2726,22 +2750,22 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( return success(); }); - patterns.onOp("LeakyRelu", 1, - [](OpBinder binder, ConversionPatternRewriter &rewriter) { - Torch::ValueTensorType resultType; - Value operand; - float alpha; - if (binder.tensorOperand(operand) || - binder.tensorResultType(resultType) || - binder.f32FloatAttr(alpha, "alpha", 0.01f)) - return failure(); - Value constAlpha = rewriter.create( - binder.getLoc(), rewriter.getType(), - 
rewriter.getF64FloatAttr(alpha)); - rewriter.replaceOpWithNewOp( - binder.op, resultType, operand, constAlpha); - return success(); - }); + patterns.onOp( + "LeakyRelu", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + Value operand; + float alpha; + if (binder.tensorOperand(operand) || + binder.tensorResultType(resultType) || + binder.f32FloatAttr(alpha, "alpha", 0.01f)) + return failure(); + Value constAlpha = Torch::ConstantFloatOp::create( + rewriter, binder.getLoc(), rewriter.getType(), + rewriter.getF64FloatAttr(alpha)); + rewriter.replaceOpWithNewOp( + binder.op, resultType, operand, constAlpha); + return success(); + }); patterns.onOp( "LRN", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) { Torch::ValueTensorType resultType; @@ -2756,14 +2780,14 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( binder.f32FloatAttr(bias, "bias", 1.0f)) return failure(); Type dtype = resultType.getOptionalDtype(); - Value constAlpha = rewriter.create( - binder.getLoc(), rewriter.getType(), + Value constAlpha = Torch::ConstantFloatOp::create( + rewriter, binder.getLoc(), rewriter.getType(), rewriter.getF64FloatAttr(alpha)); - Value constBeta = rewriter.create( - binder.getLoc(), rewriter.getType(), + Value constBeta = Torch::ConstantFloatOp::create( + rewriter, binder.getLoc(), rewriter.getType(), rewriter.getF64FloatAttr(beta)); - Value constBias = rewriter.create( - binder.getLoc(), rewriter.getType(), + Value constBias = Torch::ConstantFloatOp::create( + rewriter, binder.getLoc(), rewriter.getType(), rewriter.getF64FloatAttr(bias)); // Please refer to the operator description // for more info on the lowering @@ -2773,8 +2797,8 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( Location loc = binder.getLoc(); Torch::ValueTensorType inTy = cast(operand.getType()); - Value sqOperand = rewriter.create( - loc, inTy, operand, operand); + Value sqOperand = Torch::AtenMulTensorOp::create(rewriter, loc, 
inTy, + operand, operand); // view it as n x 1 x c x d0 x d.. if (!inTy.hasSizes()) { return rewriter.notifyMatchFailure(binder.op, @@ -2796,14 +2820,14 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( rewriter.getType(viewShapeInt, dtype); Value viewShapeListVal = createConstantIntList(binder, rewriter, viewShapeInt); - auto view = rewriter.create( - loc, reshapeType, sqOperand, viewShapeListVal); + auto view = Torch::AtenViewOp::create(rewriter, loc, reshapeType, + sqOperand, viewShapeListVal); // padding int64_t highPad = (size - 1) / 2; int64_t lowPad = (size - 1) - highPad; SmallVector paddingInt{0, 0, 0, 0, lowPad, highPad}; - auto constPadVal = rewriter.create( - loc, rewriter.getType(), + auto constPadVal = Torch::ConstantFloatOp::create( + rewriter, loc, rewriter.getType(), rewriter.getF64FloatAttr(0.0)); Value paddingListVal = createConstantIntList(binder, rewriter, paddingInt); @@ -2811,8 +2835,8 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( paddedShapeInt[2] += size - 1; Torch::ValueTensorType paddedType = rewriter.getType(paddedShapeInt, dtype); - auto padded = rewriter.create( - loc, paddedType, view, paddingListVal, constPadVal); + auto padded = Torch::AtenConstantPadNdOp::create( + rewriter, loc, paddedType, view, paddingListVal, constPadVal); // avg_pool3d SmallVector kernelSize{size, 1, 1}; Value kernelSizeList = @@ -2822,36 +2846,38 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( SmallVector padding{0, 0, 0}; Value paddingList = createConstantIntList(binder, rewriter, padding); auto cstCeilMode = - rewriter.create(binder.getLoc(), false); + Torch::ConstantBoolOp::create(rewriter, binder.getLoc(), false); auto cstCountIncludeMode = - rewriter.create(binder.getLoc(), true); - Value cstNone = rewriter.create(binder.getLoc()); + Torch::ConstantBoolOp::create(rewriter, binder.getLoc(), true); + Value cstNone = + Torch::ConstantNoneOp::create(rewriter, binder.getLoc()); // Output of pooling is same reshape(view) type because // 
of the padding done on the dimensions being pooled. - auto pool = rewriter.create( - loc, reshapeType, padded, kernelSizeList, stridesList, paddingList, - cstCeilMode, cstCountIncludeMode, /*divisor_override=*/cstNone); + auto pool = Torch::AtenAvgPool3dOp::create( + rewriter, loc, reshapeType, padded, kernelSizeList, stridesList, + paddingList, cstCeilMode, cstCountIncludeMode, + /*divisor_override=*/cstNone); // squeeze - auto one = rewriter.create( - loc, rewriter.getI64IntegerAttr(1)); + auto one = Torch::ConstantIntOp::create(rewriter, loc, + rewriter.getI64IntegerAttr(1)); SmallVector squeezeShapeInt{ viewShapeInt[0], viewShapeInt[2], viewShapeInt[3], viewShapeInt[4]}; Torch::ValueTensorType squeezeType = rewriter.getType(squeezeShapeInt, dtype); - auto squeeze = rewriter.create( - loc, squeezeType, pool, one); + auto squeeze = Torch::AtenSqueezeDimOp::create(rewriter, loc, + squeezeType, pool, one); // view as input Type Value intTyShapeList = createConstantIntList(binder, rewriter, inTyShape); - auto viewAsInput = rewriter.create( - loc, inTy, squeeze, intTyShapeList); + auto viewAsInput = Torch::AtenViewOp::create(rewriter, loc, inTy, + squeeze, intTyShapeList); // mul + add + pow + div - auto mul = rewriter.create( - loc, resultType, viewAsInput, constAlpha); - auto add = rewriter.create(loc, resultType, mul, - constBias, one); - auto pow = rewriter.create( - loc, resultType, add, constBeta); + auto mul = Torch::AtenMulScalarOp::create(rewriter, loc, resultType, + viewAsInput, constAlpha); + auto add = Torch::AtenAddScalarOp::create(rewriter, loc, resultType, + mul, constBias, one); + auto pow = Torch::AtenPowTensorScalarOp::create( + rewriter, loc, resultType, add, constBeta); rewriter.replaceOpWithNewOp( binder.op, resultType, operand, pow); @@ -2902,8 +2928,8 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( padInts = paddingsInts; } for (auto p : padInts) - padsTensorValue.push_back(rewriter.create( - loc, rewriter.getI64IntegerAttr(p))); + 
padsTensorValue.push_back(Torch::ConstantIntOp::create( + rewriter, loc, rewriter.getI64IntegerAttr(p))); } else { // Get pads shape and rank. The pads tensor is expected to be 1-D // tensor. @@ -2926,19 +2952,19 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( // Extract all the values of 1-D pad tensor and create a list of all // these values as torch.pad op expects pad list. - Value constZero = rewriter.create( - loc, rewriter.getI64IntegerAttr(0)); + Value constZero = Torch::ConstantIntOp::create( + rewriter, loc, rewriter.getI64IntegerAttr(0)); SmallVector emptyShape; Type padsElemType = Torch::ValueTensorType::get( padsTensorType.getContext(), emptyShape, padsTensorType.getOptionalDtype()); for (uint32_t i = 0; i < padsSize; ++i) { - Value index = rewriter.create( - loc, rewriter.getI64IntegerAttr(i)); - auto select = rewriter.create( - loc, padsElemType, pads, constZero, index); - Value selectInt = rewriter.create( - loc, rewriter.getType(), select); + Value index = Torch::ConstantIntOp::create( + rewriter, loc, rewriter.getI64IntegerAttr(i)); + auto select = Torch::AtenSelectIntOp::create( + rewriter, loc, padsElemType, pads, constZero, index); + Value selectInt = Torch::AtenItemOp::create( + rewriter, loc, rewriter.getType(), select); padsTensorValue.push_back(selectInt); } } @@ -2955,24 +2981,24 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( Type scalarTy = rewriter.getType(); if (isa(constTy.getDtype())) scalarTy = rewriter.getType(); - constantValue = rewriter.create(loc, scalarTy, - constantValue); + constantValue = Torch::AtenItemOp::create(rewriter, loc, scalarTy, + constantValue); } } if (!constantValue && cstMode) { auto dataTensorType = cast(data.getType()); if (isa(dataTensorType.getDtype())) - constantValue = rewriter.create( - loc, rewriter.getI64IntegerAttr(0)); + constantValue = Torch::ConstantIntOp::create( + rewriter, loc, rewriter.getI64IntegerAttr(0)); // Earlier versions used a FLOAT attribute to store the constant // 
value. The following will pick up on any non-default value attr if // provided. float constantFloat; if (isa(dataTensorType.getDtype()) && !binder.f32FloatAttr(constantFloat, "value", 0.0f)) - constantValue = rewriter.create( - loc, rewriter.getF64FloatAttr(constantFloat)); + constantValue = Torch::ConstantFloatOp::create( + rewriter, loc, rewriter.getF64FloatAttr(constantFloat)); if (!constantValue) return rewriter.notifyMatchFailure( @@ -2981,7 +3007,7 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( // for modes other than "constant" a value is not required if (!cstMode) - constantValue = rewriter.create(loc); + constantValue = Torch::ConstantNoneOp::create(rewriter, loc); llvm::SmallVector begins; llvm::SmallVector ends; @@ -3000,8 +3026,8 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( int64_t rank = dataTensorType.getSizes().size(); auto boolTy = rewriter.getType(); auto intTy = rewriter.getType(); - Value constZero = rewriter.create( - loc, rewriter.getI64IntegerAttr(0)); + Value constZero = Torch::ConstantIntOp::create( + rewriter, loc, rewriter.getI64IntegerAttr(0)); // Extract the values: int64_t numAxes = axesTy.getSizes()[0]; @@ -3009,24 +3035,24 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( axesTy.getContext(), ArrayRef{}, axesTy.getOptionalDtype()); llvm::SmallVector axesExtracted; - Value rankV = rewriter.create( - loc, rewriter.getI64IntegerAttr(rank)); + Value rankV = Torch::ConstantIntOp::create( + rewriter, loc, rewriter.getI64IntegerAttr(rank)); for (uint32_t i = 0; i < numAxes; ++i) { - Value index = rewriter.create( - loc, rewriter.getI64IntegerAttr(i)); - auto select = rewriter.create( - loc, axesElemType, axes, constZero, index); - Value selectInt = rewriter.create( - loc, rewriter.getType(), select); - - Value negAxis = rewriter.create( - loc, boolTy, selectInt, constZero); + Value index = Torch::ConstantIntOp::create( + rewriter, loc, rewriter.getI64IntegerAttr(i)); + auto select = Torch::AtenSelectIntOp::create( + 
rewriter, loc, axesElemType, axes, constZero, index); + Value selectInt = Torch::AtenItemOp::create( + rewriter, loc, rewriter.getType(), select); + + Value negAxis = Torch::AtenLtIntOp::create(rewriter, loc, boolTy, + selectInt, constZero); negAxis = - rewriter.create(loc, intTy, negAxis); - Value axis = rewriter.create(loc, intTy, - negAxis, rankV); - axis = rewriter.create(loc, intTy, axis, - selectInt); + Torch::AtenIntBoolOp::create(rewriter, loc, intTy, negAxis); + Value axis = Torch::AtenMulIntOp::create(rewriter, loc, intTy, + negAxis, rankV); + axis = Torch::AtenAddIntOp::create(rewriter, loc, intTy, axis, + selectInt); axesExtracted.push_back(axis); } @@ -3036,27 +3062,27 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( for (int j = 0; j < rank; ++j) { Value newBegin = constZero; Value newEnd = constZero; - Value iv = rewriter.create( - loc, rewriter.getI64IntegerAttr(j)); + Value iv = Torch::ConstantIntOp::create( + rewriter, loc, rewriter.getI64IntegerAttr(j)); for (size_t i = 0; i < axesExtracted.size(); ++i) { Value begin = begins[i]; Value end = ends[i]; - Value sameAxis = rewriter.create( - loc, boolTy, axesExtracted[i], iv); + Value sameAxis = Torch::AtenEqIntOp::create(rewriter, loc, boolTy, + axesExtracted[i], iv); sameAxis = - rewriter.create(loc, intTy, sameAxis); + Torch::AtenIntBoolOp::create(rewriter, loc, intTy, sameAxis); - begin = rewriter.create(loc, intTy, sameAxis, - begin); - end = rewriter.create(loc, intTy, sameAxis, - end); + begin = Torch::AtenMulIntOp::create(rewriter, loc, intTy, + sameAxis, begin); + end = Torch::AtenMulIntOp::create(rewriter, loc, intTy, sameAxis, + end); - newBegin = rewriter.create(loc, intTy, - newBegin, begin); - newEnd = - rewriter.create(loc, intTy, newEnd, end); + newBegin = Torch::AtenAddIntOp::create(rewriter, loc, intTy, + newBegin, begin); + newEnd = Torch::AtenAddIntOp::create(rewriter, loc, intTy, newEnd, + end); } newBegins.push_back(newBegin); @@ -3080,11 +3106,10 @@ void 
mlir::torch::onnx_c::populateDefaultDomainGtoP( } Value padsSizeList = - rewriter - .create( - loc, - Torch::ListType::get(rewriter.getType()), - padsRearrange) + Torch::PrimListConstructOp::create( + rewriter, loc, + Torch::ListType::get(rewriter.getType()), + padsRearrange) .getResult(); // lowering to AtenConstantPadNdOp directly allows passing any torch @@ -3100,8 +3125,8 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( mode = (mode == "edge") ? "replicate" : mode; mode = (mode == "wrap") ? "circular" : mode; - Value modeVal = rewriter.create( - loc, rewriter.getStringAttr(mode)); + Value modeVal = Torch::ConstantStrOp::create( + rewriter, loc, rewriter.getStringAttr(mode)); rewriter.replaceOpWithNewOp( binder.op, resultType, data, padsSizeList, modeVal, constantValue); @@ -3120,9 +3145,9 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( } auto loc = binder.getLoc(); - Value cstFalse = rewriter.create( - loc, rewriter.getBoolAttr(false)); - Value none = rewriter.create(loc); + Value cstFalse = Torch::ConstantBoolOp::create( + rewriter, loc, rewriter.getBoolAttr(false)); + Value none = Torch::ConstantNoneOp::create(rewriter, loc); auto powType = resultType; if (isa(resultType.getDtype())) { @@ -3130,8 +3155,8 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( resultType.getSizes(), rewriter.getF64Type()); } - Value pow = rewriter.create(loc, powType, - lhs, rhs); + Value pow = Torch::AtenPowTensorTensorOp::create(rewriter, loc, powType, + lhs, rhs); if (!isa(resultType.getDtype())) { rewriter.replaceOp(binder.op, pow); @@ -3139,30 +3164,31 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( } auto outDtype = Torch::getScalarTypeForType(resultType.getDtype()); - auto outTyConst = rewriter.create( - binder.getLoc(), rewriter.getType(), + auto outTyConst = Torch::ConstantIntOp::create( + rewriter, binder.getLoc(), rewriter.getType(), rewriter.getIntegerAttr(rewriter.getIntegerType(64), static_cast(outDtype))); - pow = rewriter.create(loc, 
powType, pow); + pow = Torch::AtenRoundOp::create(rewriter, loc, powType, pow); rewriter.replaceOpWithNewOp( binder.op, resultType, pow, outTyConst, cstFalse, cstFalse, none); return success(); }); - patterns.onOp( - "Identity", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) { - Torch::ValueTensorType resultType; - Value tensor; - if (binder.tensorOperand(tensor) || - binder.tensorResultType(resultType)) { - return failure(); - } - Value noneVal = rewriter.create(binder.getLoc()); - rewriter.replaceOpWithNewOp( - binder.op, resultType, tensor, /*memory_format=*/noneVal); - return success(); - }); + patterns.onOp("Identity", 1, + [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + Value tensor; + if (binder.tensorOperand(tensor) || + binder.tensorResultType(resultType)) { + return failure(); + } + Value noneVal = + Torch::ConstantNoneOp::create(rewriter, binder.getLoc()); + rewriter.replaceOpWithNewOp( + binder.op, resultType, tensor, /*memory_format=*/noneVal); + return success(); + }); patterns.onOp( "Mean", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) { if (binder.op->getNumOperands() == 1) { @@ -3176,18 +3202,19 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( Torch::ValueTensorType resultType; SmallVector valList; int64_t numOperands = binder.op->getNumOperands(); - Value numOperandsConstant = rewriter.create( - binder.getLoc(), rewriter.getType(), + Value numOperandsConstant = Torch::ConstantIntOp::create( + rewriter, binder.getLoc(), rewriter.getType(), rewriter.getIntegerAttr(rewriter.getIntegerType(64), numOperands)); if (binder.tensorOperands(valList, numOperands) || binder.tensorResultType(resultType)) return failure(); - Value constOne = rewriter.create( - binder.getLoc(), rewriter.getType(), + Value constOne = Torch::ConstantIntOp::create( + rewriter, binder.getLoc(), rewriter.getType(), rewriter.getIntegerAttr(rewriter.getIntegerType(64), 1)); // Short circuit to binary 
add - Value curr = rewriter.create( - binder.getLoc(), resultType, valList[0], valList[1], constOne); + Value curr = Torch::AtenAddTensorOp::create(rewriter, binder.getLoc(), + resultType, valList[0], + valList[1], constOne); if (numOperands == 2) { rewriter.replaceOpWithNewOp( binder.op, resultType, curr, numOperandsConstant); @@ -3198,11 +3225,13 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( binder.op->getContext()); for (int i = 2; i < numOperands; i++) { if (i == numOperands - 1) { - curr = rewriter.create( - binder.getLoc(), resultType, curr, valList[i], constOne); + curr = Torch::AtenAddTensorOp::create(rewriter, binder.getLoc(), + resultType, curr, valList[i], + constOne); } else { - curr = rewriter.create( - binder.getLoc(), baseType, curr, valList[i], constOne); + curr = Torch::AtenAddTensorOp::create(rewriter, binder.getLoc(), + baseType, curr, valList[i], + constOne); } } rewriter.replaceOpWithNewOp( @@ -3223,19 +3252,19 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( } if (neg == 0) { // replace all negative infs with 0 - tensor = rewriter.create( - binder.getLoc(), + tensor = Torch::AtenReluOp::create( + rewriter, binder.getLoc(), dyn_cast(tensor.getType()), tensor); } if (pos == 0) { // first use neg op to flip positive inf to negative inf. Then relu to // replace all positive infs with 0. - Value flip = rewriter.create( - binder.getLoc(), + Value flip = Torch::AtenNegOp::create( + rewriter, binder.getLoc(), dyn_cast(tensor.getType()), tensor); - tensor = rewriter.create( - binder.getLoc(), dyn_cast(flip.getType()), - flip); + tensor = Torch::AtenReluOp::create( + rewriter, binder.getLoc(), + dyn_cast(flip.getType()), flip); } rewriter.replaceOpWithNewOp(binder.op, resultType, tensor); @@ -3330,24 +3359,24 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( Type intTy = rewriter.getType(); Type floatTy = rewriter.getType(); Type depthETy = depthIsInt ? 
intTy : floatTy; - depth = rewriter.create(loc, depthETy, depth); + depth = Torch::AtenItemOp::create(rewriter, loc, depthETy, depth); if (!depthIsInt) - depth = rewriter.create( - loc, rewriter.getType(), depth); + depth = Torch::AtenIntScalarOp::create( + rewriter, loc, rewriter.getType(), depth); Type boolTy = rewriter.getType( indicesTy.getSizes(), rewriter.getI1Type()); - Value zero = rewriter.create( - loc, rewriter.getI64IntegerAttr(0)); - Value one = rewriter.create( - loc, rewriter.getI64IntegerAttr(1)); + Value zero = Torch::ConstantIntOp::create( + rewriter, loc, rewriter.getI64IntegerAttr(0)); + Value one = Torch::ConstantIntOp::create(rewriter, loc, + rewriter.getI64IntegerAttr(1)); Value lt = - rewriter.create(loc, boolTy, indices, zero); - Value add = rewriter.create( - loc, indicesTy, indices, depth, one); - indices = rewriter.create(loc, indicesTy, lt, - add, indices); + Torch::AtenLtScalarOp::create(rewriter, loc, boolTy, indices, zero); + Value add = Torch::AtenAddScalarOp::create(rewriter, loc, indicesTy, + indices, depth, one); + indices = Torch::AtenWhereSelfOp::create(rewriter, loc, indicesTy, lt, + add, indices); auto selectTy = rewriter.getType( llvm::SmallVector{1}, valuesTy.getDtype()); @@ -3355,13 +3384,13 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( bool valuesAreInt = isa(valuesTy.getDtype()); Type valuesETy = valuesAreInt ? 
intTy : floatTy; - Value off = rewriter.create(loc, selectTy, - values, zero, zero); - off = rewriter.create(loc, valuesETy, off); + Value off = Torch::AtenSelectIntOp::create(rewriter, loc, selectTy, + values, zero, zero); + off = Torch::AtenItemOp::create(rewriter, loc, valuesETy, off); - Value on = rewriter.create(loc, selectTy, - values, zero, one); - on = rewriter.create(loc, valuesETy, on); + Value on = Torch::AtenSelectIntOp::create(rewriter, loc, selectTy, + values, zero, one); + on = Torch::AtenItemOp::create(rewriter, loc, valuesETy, on); auto i32Ty = rewriter.getIntegerType(32, true); llvm::SmallVector onehotShape(indicesTy.getSizes()); @@ -3369,40 +3398,40 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( auto onehotTy = rewriter.getType(onehotShape, i32Ty); - Value onehot = rewriter.create( - binder.getLoc(), onehotTy, indices, depth); + Value onehot = Torch::AtenOneHotOp::create(rewriter, binder.getLoc(), + onehotTy, indices, depth); for (int i = indicesTy.getSizes().size(); i > axis; --i) { std::swap(onehotShape[i - 1], onehotShape[i]); - Value iv0 = rewriter.create( - loc, rewriter.getI64IntegerAttr(i)); - Value iv1 = rewriter.create( - loc, rewriter.getI64IntegerAttr(i - 1)); + Value iv0 = Torch::ConstantIntOp::create( + rewriter, loc, rewriter.getI64IntegerAttr(i)); + Value iv1 = Torch::ConstantIntOp::create( + rewriter, loc, rewriter.getI64IntegerAttr(i - 1)); onehotTy = rewriter.getType(onehotShape, i32Ty); - onehot = rewriter.create(loc, onehotTy, - onehot, iv1, iv0); + onehot = Torch::AtenTransposeIntOp::create(rewriter, loc, onehotTy, + onehot, iv1, iv0); } // Change one hot to an array of booleans to select value: auto i1Ty = rewriter.getI1Type(); auto torchqTy = Torch::getScalarTypeForType(i1Ty); - Value tyConst = rewriter.create( - binder.getLoc(), rewriter.getType(), + Value tyConst = Torch::ConstantIntOp::create( + rewriter, binder.getLoc(), rewriter.getType(), rewriter.getIntegerAttr(rewriter.getIntegerType(64), 
static_cast(torchqTy))); onehotTy = rewriter.getType(onehotShape, i1Ty); - Value none = rewriter.create(loc); - Value cstFalse = rewriter.create(loc, false); - onehot = rewriter.create( - loc, onehotTy, onehot, tyConst, + Value none = Torch::ConstantNoneOp::create(rewriter, loc); + Value cstFalse = Torch::ConstantBoolOp::create(rewriter, loc, false); + onehot = Torch::AtenToDtypeOp::create( + rewriter, loc, onehotTy, onehot, tyConst, /*non_blocking=*/cstFalse, /*copy=*/cstFalse, /*memory_format=*/none); - onehot = rewriter.create(loc, resultType, - onehot, on, off); + onehot = Torch::AtenWhereScalarOp::create(rewriter, loc, resultType, + onehot, on, off); rewriter.replaceOp(binder.op, onehot); return success(); @@ -3452,12 +3481,12 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( if (axisValue < 0) axisValue += inputTy.getSizes().size(); - axis = rewriter.create( - loc, rewriter.getI64IntegerAttr(axisValue)); + axis = Torch::ConstantIntOp::create( + rewriter, loc, rewriter.getI64IntegerAttr(axisValue)); // torch.argmax - Value constKeepDims = rewriter.create( - loc, rewriter.getType(), + Value constKeepDims = Torch::ConstantBoolOp::create( + rewriter, loc, rewriter.getType(), rewriter.getBoolAttr(false)); SmallVector argmaxShape; @@ -3469,18 +3498,18 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( auto argmaxTy = rewriter.getType( argmaxShape, rewriter.getIntegerType(32, IntegerType::Signed)); - Value argmax = rewriter.create( - loc, argmaxTy, input, axis, constKeepDims); + Value argmax = Torch::AtenArgmaxOp::create(rewriter, loc, argmaxTy, + input, axis, constKeepDims); // one_hot SmallVector onehotShape(argmaxShape); onehotShape.push_back(inputTy.getSizes()[axisValue]); auto onehotTy = rewriter.getType( onehotShape, resultType.getDtype()); - Value numClasses = - rewriter.create(binder.getLoc(), input, axis); - Value onehot = rewriter.create( - binder.getLoc(), onehotTy, argmax, numClasses); + Value numClasses = Torch::AtenSizeIntOp::create( + 
rewriter, binder.getLoc(), input, axis); + Value onehot = Torch::AtenOneHotOp::create( + rewriter, binder.getLoc(), onehotTy, argmax, numClasses); SmallVector permutation; for (int i = 0; i < axisValue; ++i) @@ -3491,12 +3520,13 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( SmallVector permValues; for (auto d : permutation) { - permValues.push_back(rewriter.create( - binder.getLoc(), rewriter.getI64IntegerAttr(d))); + permValues.push_back(Torch::ConstantIntOp::create( + rewriter, binder.getLoc(), rewriter.getI64IntegerAttr(d))); } - Value permuteDims = rewriter.create( - loc, Torch::ListType::get(rewriter.getType()), + Value permuteDims = Torch::PrimListConstructOp::create( + rewriter, loc, + Torch::ListType::get(rewriter.getType()), permValues); rewriter.replaceOpWithNewOp(binder.op, resultType, onehot, permuteDims); @@ -3514,18 +3544,17 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( return failure(); auto loc = binder.getLoc(); - Value cstAxis = rewriter.create( - loc, rewriter.getI64IntegerAttr(axis)); - Value cstP = rewriter.create( - loc, rewriter.getI64IntegerAttr(p)); - Value cstKeepDim = rewriter.create( - loc, rewriter.getBoolAttr(true)); - Value axisPrimList = - rewriter.create( - binder.getLoc(), - rewriter.getType( - rewriter.getType()), - llvm::ArrayRef{cstAxis}); + Value cstAxis = Torch::ConstantIntOp::create( + rewriter, loc, rewriter.getI64IntegerAttr(axis)); + Value cstP = Torch::ConstantIntOp::create( + rewriter, loc, rewriter.getI64IntegerAttr(p)); + Value cstKeepDim = Torch::ConstantBoolOp::create( + rewriter, loc, rewriter.getBoolAttr(true)); + Value axisPrimList = Torch::PrimListConstructOp::create( + rewriter, binder.getLoc(), + rewriter.getType( + rewriter.getType()), + llvm::ArrayRef{cstAxis}); SmallVector normSizes(resultType.getSizes()); int64_t rank = normSizes.size(); @@ -3534,8 +3563,9 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( normSizes[axis] = 1; auto normType = rewriter.getType( normSizes, 
resultType.getDtype()); - Value norm = rewriter.create( - loc, normType, input, cstP, axisPrimList, cstKeepDim); + Value norm = Torch::AtenNormScalarOptDimOp::create( + rewriter, loc, normType, input, cstP, axisPrimList, + cstKeepDim); rewriter.replaceOpWithNewOp( binder.op, resultType, input, norm); @@ -3656,13 +3686,13 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( return rewriter.notifyMatchFailure( binder.op, "unimplemented: stash_type != input dtype"); - Value cstEpsilon = rewriter.create( - binder.getLoc(), rewriter.getType(), + Value cstEpsilon = Torch::ConstantFloatOp::create( + rewriter, binder.getLoc(), rewriter.getType(), rewriter.getF64FloatAttr((double)epsilon)); - Value cstNumGroups = rewriter.create( - binder.getLoc(), rewriter.getI64IntegerAttr(numGroups)); - Value cstFalse = rewriter.create( - binder.getLoc(), rewriter.getBoolAttr(false)); + Value cstNumGroups = Torch::ConstantIntOp::create( + rewriter, binder.getLoc(), rewriter.getI64IntegerAttr(numGroups)); + Value cstFalse = Torch::ConstantBoolOp::create( + rewriter, binder.getLoc(), rewriter.getBoolAttr(false)); rewriter.replaceOpWithNewOp( binder.op, resultType, input, cstNumGroups, scale, bias, cstEpsilon, /*cudnn_enabled=*/cstFalse); @@ -3744,17 +3774,19 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( else output = false; - Value cstOutput = rewriter.create( - binder.getLoc(), rewriter.getI64IntegerAttr((int64_t)output)); - Value cstDtype = rewriter.create( - binder.getLoc(), + Value cstOutput = Torch::ConstantIntOp::create( + rewriter, binder.getLoc(), + rewriter.getI64IntegerAttr((int64_t)output)); + Value cstDtype = Torch::ConstantIntOp::create( + rewriter, binder.getLoc(), rewriter.getI64IntegerAttr((int)torch_upstream::ScalarType::Bool)); - Value cstFalse = rewriter.create( - binder.getLoc(), rewriter.getBoolAttr(false)); - Value cstNone = rewriter.create(binder.getLoc()); + Value cstFalse = Torch::ConstantBoolOp::create( + rewriter, binder.getLoc(), 
rewriter.getBoolAttr(false)); + Value cstNone = + Torch::ConstantNoneOp::create(rewriter, binder.getLoc()); - Value dataList = rewriter.create( - binder.getLoc(), + Value dataList = Torch::PrimListConstructOp::create( + rewriter, binder.getLoc(), rewriter.getType( rewriter.getType()), SmallVector{cstOutput}); @@ -3812,119 +3844,121 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( // to [[x1, y1, x2, y2], ...] auto boxesTensorType = dyn_cast(boxes.getType()); - Value const0 = rewriter.create( - loc, rewriter.getI64IntegerAttr(0)); - Value const1 = rewriter.create( - loc, rewriter.getI64IntegerAttr(1)); - Value const2 = rewriter.create( - loc, rewriter.getI64IntegerAttr(2)); - Value const4 = rewriter.create( - loc, rewriter.getI64IntegerAttr(4)); - Value const2F = rewriter.create( - loc, rewriter.getF64FloatAttr(2.0)); + Value const0 = Torch::ConstantIntOp::create( + rewriter, loc, rewriter.getI64IntegerAttr(0)); + Value const1 = Torch::ConstantIntOp::create( + rewriter, loc, rewriter.getI64IntegerAttr(1)); + Value const2 = Torch::ConstantIntOp::create( + rewriter, loc, rewriter.getI64IntegerAttr(2)); + Value const4 = Torch::ConstantIntOp::create( + rewriter, loc, rewriter.getI64IntegerAttr(4)); + Value const2F = Torch::ConstantFloatOp::create( + rewriter, loc, rewriter.getF64FloatAttr(2.0)); // extract scaled ranges for regions of interest auto sliceShape = SmallVector{Torch::kUnknownSize, 2}; auto sliceTensorType = rewriter.getType( sliceShape, boxesTensorType.getDtype()); - Value centers = rewriter.create( - loc, sliceTensorType, boxes, const1, const0, const2, const1); - Value sizes = rewriter.create( - loc, sliceTensorType, boxes, const1, const2, const4, const1); - Value halfSizes = rewriter.create( - loc, sizes.getType(), sizes, const2F); - Value x1y1s = rewriter.create( - loc, centers.getType(), centers, halfSizes, const1); - Value x2y2s = rewriter.create( - loc, centers.getType(), centers, halfSizes, const1); + Value centers = 
Torch::AtenSliceTensorOp::create( + rewriter, loc, sliceTensorType, boxes, const1, const0, const2, + const1); + Value sizes = Torch::AtenSliceTensorOp::create( + rewriter, loc, sliceTensorType, boxes, const1, const2, const4, + const1); + Value halfSizes = Torch::AtenDivScalarOp::create( + rewriter, loc, sizes.getType(), sizes, const2F); + Value x1y1s = Torch::AtenSubTensorOp::create( + rewriter, loc, centers.getType(), centers, halfSizes, const1); + Value x2y2s = Torch::AtenAddTensorOp::create( + rewriter, loc, centers.getType(), centers, halfSizes, const1); Type listElemType = boxesTensorType.getWithSizesAndDtype( /*optionalSizes=*/std::nullopt, /*optionalDtype=*/nullptr); Type listType = Torch::ListType::get(listElemType); - Value tensorList = rewriter.create( - loc, listType, SmallVector{x1y1s, x2y2s}); - boxes = rewriter.create(loc, boxesTensorType, - tensorList, const1); + Value tensorList = Torch::PrimListConstructOp::create( + rewriter, loc, listType, SmallVector{x1y1s, x2y2s}); + boxes = Torch::AtenCatOp::create(rewriter, loc, boxesTensorType, + tensorList, const1); } // TODO: Support score_threshold input // Filter out the boxes if the score < score_threshold if (operands.size() == 5) { - Value scoreThreshold = rewriter.create( - loc, rewriter.getType(), operands[4]); - Value minScores = rewriter.create( - loc, + Value scoreThreshold = Torch::AtenItemOp::create( + rewriter, loc, rewriter.getType(), operands[4]); + Value minScores = Torch::AtenMinOp::create( + rewriter, loc, Torch::ValueTensorType::get(binder.op->getContext(), SmallVector{}, rewriter.getF32Type()), scores); - minScores = rewriter.create( - loc, rewriter.getType(), minScores); + minScores = Torch::AtenItemOp::create( + rewriter, loc, rewriter.getType(), minScores); - Value scoresCond = rewriter.create( - loc, minScores, scoreThreshold); - rewriter.create( - loc, scoresCond, + Value scoresCond = Torch::AtenGeFloatOp::create( + rewriter, loc, minScores, scoreThreshold); + 
Torch::RuntimeAssertOp::create( + rewriter, loc, scoresCond, rewriter.getStringAttr( "unimplemented: score_threshold should be <= min(scores)")); } // Get max_output_boxes_per_class and iou_threshold - Value cst0 = rewriter.create( - loc, rewriter.getI64IntegerAttr(0)); - Value cst1 = rewriter.create( - loc, rewriter.getI64IntegerAttr(1)); + Value cst0 = Torch::ConstantIntOp::create( + rewriter, loc, rewriter.getI64IntegerAttr(0)); + Value cst1 = Torch::ConstantIntOp::create( + rewriter, loc, rewriter.getI64IntegerAttr(1)); Value maxOutputBoxesPerClass = cst0; - Value iouThreshold = rewriter.create( - loc, rewriter.getF64FloatAttr(0.0)); + Value iouThreshold = Torch::ConstantFloatOp::create( + rewriter, loc, rewriter.getF64FloatAttr(0.0)); if (operands.size() > 3 && !isa(operands[3].getType())) { - iouThreshold = rewriter.create( - loc, rewriter.getType(), operands[3]); + iouThreshold = Torch::AtenItemOp::create( + rewriter, loc, rewriter.getType(), operands[3]); } if (operands.size() > 2 && !isa(operands[2].getType())) { - maxOutputBoxesPerClass = rewriter.create( - loc, rewriter.getType(), operands[2]); + maxOutputBoxesPerClass = Torch::AtenItemOp::create( + rewriter, loc, rewriter.getType(), operands[2]); } auto nmsTy = Torch::ValueTensorType::get( binder.op->getContext(), SmallVector{-1}, rewriter.getIntegerType(64, /*signed=*/true)); - Value result = rewriter.create( - loc, nmsTy, boxes, scores, iouThreshold); + Value result = Torch::TorchvisionNmsOp::create( + rewriter, loc, nmsTy, boxes, scores, iouThreshold); // Slice the result if numOutputBoxes (N) > max_output_boxes_per_class Value numOutputBoxes = - rewriter.create(loc, result, cst0); - Value boxesCond = rewriter.create( - loc, numOutputBoxes, maxOutputBoxesPerClass); + Torch::AtenSizeIntOp::create(rewriter, loc, result, cst0); + Value boxesCond = Torch::AtenGtIntOp::create( + rewriter, loc, numOutputBoxes, maxOutputBoxesPerClass); auto nmsResultTy = Torch::ValueTensorType::get( binder.op->getContext(), 
SmallVector{resultType.getSizes()[0]}, rewriter.getIntegerType(64, /*signed=*/true)); - auto ifSlice = rewriter.create( - loc, TypeRange({nmsResultTy}), boxesCond); + auto ifSlice = Torch::PrimIfOp::create( + rewriter, loc, TypeRange({nmsResultTy}), boxesCond); { PatternRewriter::InsertionGuard guard(rewriter); rewriter.createBlock(&ifSlice.getThenRegion(), ifSlice.getThenRegion().begin()); - Value curResult = rewriter.create( - loc, nmsResultTy, result, /*dim=*/cst0, /*start=*/cst0, + Value curResult = Torch::AtenSliceTensorOp::create( + rewriter, loc, nmsResultTy, result, /*dim=*/cst0, /*start=*/cst0, /*end=*/maxOutputBoxesPerClass, /*step=*/cst1); - rewriter.create(loc, curResult); + Torch::PrimIfYieldOp::create(rewriter, loc, curResult); } { PatternRewriter::InsertionGuard guard(rewriter); rewriter.createBlock(&ifSlice.getElseRegion(), ifSlice.getElseRegion().begin()); - Value curResult = rewriter.create( - loc, nmsResultTy, result); - rewriter.create(loc, curResult); + Value curResult = Torch::TensorStaticInfoCastOp::create( + rewriter, loc, nmsResultTy, result); + Torch::PrimIfYieldOp::create(rewriter, loc, curResult); } result = ifSlice.getResult(0); @@ -3940,12 +3974,12 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( result = unsqueezedResult.value(); numOutputBoxes = - rewriter.create(loc, result, cst0); + Torch::AtenSizeIntOp::create(rewriter, loc, result, cst0); SmallVector zerosShapeValues{numOutputBoxes}; - zerosShapeValues.push_back(rewriter.create( - loc, rewriter.getI64IntegerAttr(2))); - Value zerosShapeList = rewriter.create( - loc, + zerosShapeValues.push_back(Torch::ConstantIntOp::create( + rewriter, loc, rewriter.getI64IntegerAttr(2))); + Value zerosShapeList = Torch::PrimListConstructOp::create( + rewriter, loc, rewriter.getType( rewriter.getType()), zerosShapeValues); @@ -3957,17 +3991,18 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( llvm::SmallVector zerosShape = {resultShape->front(), 2}; auto zerosTy = 
Torch::ValueTensorType::get( resultType.getContext(), zerosShape, resultType.getOptionalDtype()); - Value cstNone = rewriter.create(loc); - Value zeros = rewriter.create( - loc, zerosTy, zerosShapeList, cstNone, cstNone, cstNone, cstNone); + Value cstNone = Torch::ConstantNoneOp::create(rewriter, loc); + Value zeros = + Torch::AtenZerosOp::create(rewriter, loc, zerosTy, zerosShapeList, + cstNone, cstNone, cstNone, cstNone); Type listElemType = cast(resultType) .getWithSizesAndDtype(/*optionalSizes=*/std::nullopt, /*optionalDtype=*/nullptr); Type listType = Torch::ListType::get(listElemType); - Value tensorList = rewriter.create( - loc, listType, SmallVector{zeros, result}); + Value tensorList = Torch::PrimListConstructOp::create( + rewriter, loc, listType, SmallVector{zeros, result}); rewriter.replaceOpWithNewOp(binder.op, resultType, tensorList, cst1); return success(); diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp index d2e3c94733e9..0de8ffa5d5d8 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp @@ -86,8 +86,8 @@ LogicalResult reduceOpImpl(OpBinder binder, ConversionPatternRewriter &rewriter, } if (reduceDims.size() == numAxes) { for (auto i : reduceDims) { - axesList.push_back(rewriter.create( - binder.getLoc(), rewriter.getI64IntegerAttr(i))); + axesList.push_back(Torch::ConstantIntOp::create( + rewriter, binder.getLoc(), rewriter.getI64IntegerAttr(i))); } } else binder.op->emitWarning( @@ -101,20 +101,21 @@ LogicalResult reduceOpImpl(OpBinder binder, ConversionPatternRewriter &rewriter, if (axesTy.getSizes()[0] == Torch::kUnknownSize) return failure(); - Value zero = rewriter.create( - binder.getLoc(), rewriter.getType(), + Value zero = Torch::ConstantIntOp::create( + rewriter, binder.getLoc(), rewriter.getType(), rewriter.getI64IntegerAttr(0)); SmallVector selectSizes{1}; auto selType = rewriter.getType( 
selectSizes, axesTy.getOptionalDtype()); for (uint64_t i = 0; i < numAxes; ++i) { - Value iv = rewriter.create( - binder.getLoc(), rewriter.getType(), + Value iv = Torch::ConstantIntOp::create( + rewriter, binder.getLoc(), rewriter.getType(), rewriter.getI64IntegerAttr(i)); - Value extract = rewriter.create( - binder.getLoc(), selType, axesVal, zero, iv); - Value dim = rewriter.create( - binder.getLoc(), rewriter.getType(), extract); + Value extract = Torch::AtenSelectIntOp::create( + rewriter, binder.getLoc(), selType, axesVal, zero, iv); + Value dim = Torch::AtenItemOp::create( + rewriter, binder.getLoc(), rewriter.getType(), + extract); axesList.push_back(dim); } } @@ -124,7 +125,7 @@ LogicalResult reduceOpImpl(OpBinder binder, ConversionPatternRewriter &rewriter, if (!binder.s64IntegerArrayAttr(axesInts, "axes", {})) { for (int64_t i : axesInts) { axesList.push_back( - rewriter.create(binder.getLoc(), i)); + Torch::ConstantIntOp::create(rewriter, binder.getLoc(), i)); } } @@ -152,16 +153,16 @@ LogicalResult reduceOpImpl(OpBinder binder, ConversionPatternRewriter &rewriter, "dimensions equal to 1."); for (uint64_t i = 0; i < inputType.getSizes().size(); i++) { axesList.push_back( - rewriter.create(binder.getLoc(), i)); + Torch::ConstantIntOp::create(rewriter, binder.getLoc(), i)); } } - Value dimValueList = rewriter.create( - binder.getLoc(), + Value dimValueList = Torch::PrimListConstructOp::create( + rewriter, binder.getLoc(), Torch::ListType::get(Torch::IntType::get(binder.op->getContext())), axesList); Value keepDimBool = - rewriter.create(binder.getLoc(), keepDims); + Torch::ConstantBoolOp::create(rewriter, binder.getLoc(), keepDims); // If we are using the reduction op as an intermediate op to be passed into // another operation, we might not want to replace the Op. So we create a new // Op and store the result in a variable. 
@@ -169,13 +170,13 @@ LogicalResult reduceOpImpl(OpBinder binder, ConversionPatternRewriter &rewriter, if (llvm::is_one_of()) operands.push_back( - /*dtype=*/rewriter.create(binder.getLoc())); + /*dtype=*/Torch::ConstantNoneOp::create(rewriter, binder.getLoc())); if (!isIntermediateOp) { rewriter.replaceOpWithNewOp(binder.op, resultType, operands); } else { - storeResult = rewriter.create(binder.getLoc(), - resultType, operands); + storeResult = AtenReductionTypeOp::create(rewriter, binder.getLoc(), + resultType, operands); } return success(); } @@ -203,19 +204,19 @@ Value extractTorchScalar( Type selectionTypeForSome1DTensor = some1DTensorType.getWithSizesAndDtype( ArrayRef{1}, some1DTensorType.getOptionalDtype()); - Value frontDim = rewriter.create(givenLoc, 0); + Value frontDim = Torch::ConstantIntOp::create(rewriter, givenLoc, 0); Value selectionIndex = - rewriter.create(givenLoc, givenIndex); + Torch::ConstantIntOp::create(rewriter, givenLoc, givenIndex); auto someTorchScalarType = getTorchScalarType(some1DTensorType, rewriter); - Value selectionFromGiven1DTensor = rewriter.create( - givenLoc, selectionTypeForSome1DTensor, given1DTensor, frontDim, + Value selectionFromGiven1DTensor = Torch::AtenSelectIntOp::create( + rewriter, givenLoc, selectionTypeForSome1DTensor, given1DTensor, frontDim, selectionIndex); - return rewriter.create(givenLoc, someTorchScalarType, - selectionFromGiven1DTensor); + return Torch::AtenItemOp::create(rewriter, givenLoc, someTorchScalarType, + selectionFromGiven1DTensor); } Value createScalarSublist( @@ -239,8 +240,8 @@ Value createScalarSublist( auto someTorchScalarType = runningScalarSublist.front().getType(); Type someTorchScalarListType = Torch::ListType::get(someTorchScalarType); - return rewriter.create( - givenLoc, someTorchScalarListType, runningScalarSublist); + return Torch::PrimListConstructOp::create( + rewriter, givenLoc, someTorchScalarListType, runningScalarSublist); } } // namespace @@ -283,8 +284,8 @@ void 
mlir::torch::onnx_c::populateDefaultDomainQtoZ( auto torchqTy = Torch::getScalarTypeForType(qTensorTy.getDtype()); - Value tyConst = rewriter.create( - loc, rewriter.getType(), + Value tyConst = Torch::ConstantIntOp::create( + rewriter, loc, rewriter.getType(), rewriter.getIntegerAttr(rewriter.getIntegerType(64), static_cast(torchqTy))); @@ -301,32 +302,33 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( "for floating point output."); if (isPerTensorQuantization) { - scale = rewriter.create( - loc, rewriter.getType(), scale); + scale = Torch::AtenItemOp::create( + rewriter, loc, rewriter.getType(), scale); Type zeropointTy = rewriter.getType(); if (fpResult) zeropointTy = rewriter.getType(); zeropoint = - rewriter.create(loc, zeropointTy, zeropoint); + Torch::AtenItemOp::create(rewriter, loc, zeropointTy, zeropoint); } if (!fpResult) { Value quantize; // Case 1: Per-Tensor Quantization for non-floating point input. if (isPerTensorQuantization) { - quantize = rewriter.create( - loc, qTensorTy, operand, scale, zeropoint, tyConst); + quantize = Torch::AtenQuantizePerTensorOp::create( + rewriter, loc, qTensorTy, operand, scale, zeropoint, tyConst); } else { // Case 2: Per-Channel Quantization for non-floating point input. int64_t axis; if (binder.s64IntegerAttr(axis, "axis", 1)) return failure(); - Value cstAxis = rewriter.create( - loc, rewriter.getI64IntegerAttr(axis)); - quantize = rewriter.create( - loc, qTensorTy, operand, scale, zeropoint, cstAxis, tyConst); + Value cstAxis = Torch::ConstantIntOp::create( + rewriter, loc, rewriter.getI64IntegerAttr(axis)); + quantize = Torch::AtenQuantizePerChannelOp::create( + rewriter, loc, qTensorTy, operand, scale, zeropoint, cstAxis, + tyConst); } rewriter.replaceOpWithNewOp( binder.op, resultType, quantize); @@ -334,14 +336,14 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( } // Case 3: Per-Tensor Quantization for floating point input. 
- Value none = rewriter.create(loc); - Value cstFalse = rewriter.create(loc, false); - Value one = rewriter.create( - loc, rewriter.getF64FloatAttr(1.0)); - Value div = rewriter.create( - loc, operand.getType(), operand, scale); - Value add = rewriter.create( - loc, operand.getType(), div, zeropoint, one); + Value none = Torch::ConstantNoneOp::create(rewriter, loc); + Value cstFalse = Torch::ConstantBoolOp::create(rewriter, loc, false); + Value one = Torch::ConstantFloatOp::create( + rewriter, loc, rewriter.getF64FloatAttr(1.0)); + Value div = Torch::AtenDivScalarOp::create( + rewriter, loc, operand.getType(), operand, scale); + Value add = Torch::AtenAddScalarOp::create( + rewriter, loc, operand.getType(), div, zeropoint, one); rewriter.replaceOpWithNewOp( binder.op, resultType, add, tyConst, @@ -376,8 +378,8 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( if (isa(vTy.getDtype())) extractTy = rewriter.getType(); - return rewriter.create(binder.getLoc(), extractTy, - v); + return Torch::AtenItemOp::create(rewriter, binder.getLoc(), extractTy, + v); }; inputZp = extract(inputZp); @@ -389,8 +391,8 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( Value zp) -> Value { auto ty = cast(v.getType()); auto newTy = getQTorchTypeFromTorchIntType(ty); - return rewriter.create( - binder.getLoc(), newTy, v, scale, zp); + return Torch::Aten_MakePerTensorQuantizedTensorOp::create( + rewriter, binder.getLoc(), newTy, v, scale, zp); }; // The onnx's QLinearConv op allows per channel quantization only for @@ -436,24 +438,24 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( // input_dequant = (input - input_zero_point) * input_scale // Converting the input tensor to float32 type. 
- Value none = rewriter.create(loc); - Value cstFalse = rewriter.create(loc, false); - Value float32Type = rewriter.create( - loc, rewriter.getI64IntegerAttr(/*float32Type*/ 6)); + Value none = Torch::ConstantNoneOp::create(rewriter, loc); + Value cstFalse = Torch::ConstantBoolOp::create(rewriter, loc, false); + Value float32Type = Torch::ConstantIntOp::create( + rewriter, loc, rewriter.getI64IntegerAttr(/*float32Type*/ 6)); Type f32InputType = rewriter.getType( inputTy.getSizes(), rewriter.getF32Type()); - input = rewriter.create( - loc, f32InputType, input, float32Type, - /*non_blocking=*/cstFalse, - /*copy=*/cstFalse, - /*memory_format=*/none); - - Value cstOne = rewriter.create( - loc, rewriter.getF64FloatAttr(1.0)); - input = rewriter.create( - loc, f32InputType, input, inputZp, cstOne); - input = rewriter.create(loc, f32InputType, - input, inputScale); + input = Torch::AtenToDtypeOp::create(rewriter, loc, f32InputType, + input, float32Type, + /*non_blocking=*/cstFalse, + /*copy=*/cstFalse, + /*memory_format=*/none); + + Value cstOne = Torch::ConstantFloatOp::create( + rewriter, loc, rewriter.getF64FloatAttr(1.0)); + input = Torch::AtenSubScalarOp::create(rewriter, loc, f32InputType, + input, inputZp, cstOne); + input = Torch::AtenMulScalarOp::create(rewriter, loc, f32InputType, + input, inputScale); // Dequantizing the weight // Shapes of the inputs are as follows: @@ -471,35 +473,35 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( // weight_dequant = (weight - weight_zero_point) * weight_scale int64_t diffRank = weightShape.size() - weightScaleShape.size(); for (int i = 1; i <= diffRank; i++) { - Value cstDim = rewriter.create( - loc, rewriter.getI64IntegerAttr(i)); + Value cstDim = Torch::ConstantIntOp::create( + rewriter, loc, rewriter.getI64IntegerAttr(i)); weightScaleShape.push_back(1); Type weightScaleUnsqueezeType = weightScaleTy.getWithSizesAndDtype( weightScaleShape, weightScaleTy.getOptionalDtype()); - weightScale = rewriter.create( - loc, 
weightScaleUnsqueezeType, weightScale, cstDim); + weightScale = Torch::AtenUnsqueezeOp::create( + rewriter, loc, weightScaleUnsqueezeType, weightScale, cstDim); weightZpShape.push_back(1); Type weightZpUnsqueezeType = weightZpTy.getWithSizesAndDtype( weightZpShape, weightZpTy.getOptionalDtype()); - weightZp = rewriter.create( - loc, weightZpUnsqueezeType, weightZp, cstDim); + weightZp = Torch::AtenUnsqueezeOp::create( + rewriter, loc, weightZpUnsqueezeType, weightZp, cstDim); } // Converting the weight tensor to float32 type. Type f32WeightType = rewriter.getType( weightShape, rewriter.getF32Type()); - weight = rewriter.create( - loc, f32WeightType, weight, float32Type, - /*non_blocking=*/cstFalse, - /*copy=*/cstFalse, - /*memory_format=*/none); + weight = Torch::AtenToDtypeOp::create(rewriter, loc, f32WeightType, + weight, float32Type, + /*non_blocking=*/cstFalse, + /*copy=*/cstFalse, + /*memory_format=*/none); - weight = rewriter.create( - loc, f32WeightType, weight, weightZp, cstOne); - weight = rewriter.create(loc, f32WeightType, - weight, weightScale); + weight = Torch::AtenSubTensorOp::create(rewriter, loc, f32WeightType, + weight, weightZp, cstOne); + weight = Torch::AtenMulTensorOp::create(rewriter, loc, f32WeightType, + weight, weightScale); // Converting the bias tensor to float32 type. 
if (bias) { @@ -509,11 +511,11 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( binder.op, "Expected bias argument to have sizes"); Type f32BiasType = rewriter.getType( biasTy.getSizes(), rewriter.getF32Type()); - bias = rewriter.create( - loc, f32BiasType, bias, float32Type, - /*non_blocking=*/cstFalse, - /*copy=*/cstFalse, - /*memory_format=*/none); + bias = Torch::AtenToDtypeOp::create(rewriter, loc, f32BiasType, + bias, float32Type, + /*non_blocking=*/cstFalse, + /*copy=*/cstFalse, + /*memory_format=*/none); } } else { @@ -544,38 +546,38 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( : cast(rewriter.getType()); auto outputTy = rewriter.getType( resultType.getOptionalSizes(), convDtype); - Value output = rewriter - .create( - binder.getLoc(), outputTy, newOperands, - newAttributes, binder.op->getRegions().size()) + Value output = Torch::OperatorOp::create( + rewriter, binder.getLoc(), outputTy, newOperands, + newAttributes, binder.op->getRegions().size()) .getResult(0); if (!isPerChannelQuantization) { - Value outScale = rewriter.create( - binder.getLoc(), rewriter.getType(), inputScale, - weightScale); - Value outZp = rewriter.create( - binder.getLoc(), rewriter.getType(), + Value outScale = Torch::AtenMulFloatOp::create( + rewriter, binder.getLoc(), rewriter.getType(), + inputScale, weightScale); + Value outZp = Torch::ConstantIntOp::create( + rewriter, binder.getLoc(), rewriter.getType(), rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0)); - output = rewriter.create( - binder.getLoc(), outputTy, output, outScale, outZp); + output = Torch::Aten_MakePerTensorQuantizedTensorOp::create( + rewriter, binder.getLoc(), outputTy, output, outScale, outZp); outputTy = rewriter.getType( resultType.getOptionalSizes(), rewriter.getF32Type()); - output = rewriter.create( - binder.getLoc(), outputTy, output); + output = Torch::AtenDequantizeSelfOp::create( + rewriter, binder.getLoc(), outputTy, output); } outputTy = 
getQTorchTypeFromTorchIntType(resultType); - Value dtyVal = rewriter.create( - binder.getLoc(), rewriter.getType(), + Value dtyVal = Torch::ConstantIntOp::create( + rewriter, binder.getLoc(), rewriter.getType(), rewriter.getIntegerAttr( rewriter.getIntegerType(64), static_cast( Torch::getScalarTypeForType(outputTy.getDtype())))); - output = rewriter.create( - binder.getLoc(), outputTy, output, outputScale, outputZp, dtyVal); + output = Torch::AtenQuantizePerTensorOp::create( + rewriter, binder.getLoc(), outputTy, output, outputScale, outputZp, + dtyVal); rewriter.replaceOpWithNewOp(binder.op, resultType, output); return success(); @@ -610,8 +612,8 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( binder.op, "input `a` and output not supported for non " "per-tensor quantization"); - Value emptyList = rewriter.create( - binder.getLoc(), + Value emptyList = Torch::PrimListConstructOp::create( + rewriter, binder.getLoc(), rewriter.getType( rewriter.getType()), ValueRange{}); @@ -620,16 +622,16 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( if (!vTy.getSizes().empty()) { vTy = rewriter.getType( ArrayRef({}), vTy.getOptionalDtype()); - v = rewriter.create(binder.getLoc(), vTy, v, - emptyList); + v = Torch::AtenReshapeOp::create(rewriter, binder.getLoc(), vTy, v, + emptyList); } Type extractTy = rewriter.getType(); if (isa(vTy.getDtype())) extractTy = rewriter.getType(); - return rewriter.create(binder.getLoc(), extractTy, - v); + return Torch::AtenItemOp::create(rewriter, binder.getLoc(), extractTy, + v); }; aZp = extract(aZp); @@ -641,8 +643,8 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( Value zp) -> Value { auto ty = cast(v.getType()); auto newTy = getQTorchTypeFromTorchIntType(ty); - return rewriter.create( - binder.getLoc(), newTy, v, scale, zp); + return Torch::Aten_MakePerTensorQuantizedTensorOp::create( + rewriter, binder.getLoc(), newTy, v, scale, zp); }; // The onnx's QLinearMatMul op allows per-column (per-channel) @@ -688,23 +690,24 
@@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( // a_dequant = (a - a_zero_point) * a_scale // Converting the a tensor to float32 type. - Value none = rewriter.create(loc); - Value cstFalse = rewriter.create(loc, false); - Value float32Type = rewriter.create( - loc, rewriter.getI64IntegerAttr(/*float32Type*/ 6)); + Value none = Torch::ConstantNoneOp::create(rewriter, loc); + Value cstFalse = Torch::ConstantBoolOp::create(rewriter, loc, false); + Value float32Type = Torch::ConstantIntOp::create( + rewriter, loc, rewriter.getI64IntegerAttr(/*float32Type*/ 6)); Type f32aType = rewriter.getType( aTy.getSizes(), rewriter.getF32Type()); - a = rewriter.create(loc, f32aType, a, - float32Type, - /*non_blocking=*/cstFalse, - /*copy=*/cstFalse, - /*memory_format=*/none); - - Value cstOne = rewriter.create( - loc, rewriter.getF64FloatAttr(1.0)); - a = rewriter.create(loc, f32aType, a, aZp, - cstOne); - a = rewriter.create(loc, f32aType, a, aScale); + a = Torch::AtenToDtypeOp::create(rewriter, loc, f32aType, a, + float32Type, + /*non_blocking=*/cstFalse, + /*copy=*/cstFalse, + /*memory_format=*/none); + + Value cstOne = Torch::ConstantFloatOp::create( + rewriter, loc, rewriter.getF64FloatAttr(1.0)); + a = Torch::AtenSubScalarOp::create(rewriter, loc, f32aType, a, aZp, + cstOne); + a = Torch::AtenMulScalarOp::create(rewriter, loc, f32aType, a, + aScale); // Dequantizing the b // Shapes of the inputs are as follows: @@ -719,15 +722,16 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( // Converting the b tensor to float32 type. 
Type f32bType = rewriter.getType( bShape, rewriter.getF32Type()); - b = rewriter.create(loc, f32bType, b, - float32Type, - /*non_blocking=*/cstFalse, - /*copy=*/cstFalse, - /*memory_format=*/none); - - b = rewriter.create(loc, f32bType, b, bZp, - cstOne); - b = rewriter.create(loc, f32bType, b, bScale); + b = Torch::AtenToDtypeOp::create(rewriter, loc, f32bType, b, + float32Type, + /*non_blocking=*/cstFalse, + /*copy=*/cstFalse, + /*memory_format=*/none); + + b = Torch::AtenSubTensorOp::create(rewriter, loc, f32bType, b, bZp, + cstOne); + b = Torch::AtenMulTensorOp::create(rewriter, loc, f32bType, b, + bScale); } else { llvm_unreachable( "Unidentified case for quantization for `b` argument of" @@ -747,9 +751,9 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( Value c; if (cTy.getSizes().size() == 2) { - c = rewriter.create(binder.getLoc(), cTy, a, b); + c = Torch::AtenMmOp::create(rewriter, binder.getLoc(), cTy, a, b); } else { - c = rewriter.create(binder.getLoc(), cTy, a, b); + c = Torch::AtenBmmOp::create(rewriter, binder.getLoc(), cTy, a, b); } if (!isPerColumnQuantization) { @@ -757,31 +761,31 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( resultType.getOptionalSizes(), rewriter.getType()); - Value mmScale = rewriter.create( - binder.getLoc(), rewriter.getType(), aScale, - bScale); - Value mmZp = rewriter.create( - binder.getLoc(), rewriter.getType(), + Value mmScale = Torch::AtenMulFloatOp::create( + rewriter, binder.getLoc(), rewriter.getType(), + aScale, bScale); + Value mmZp = Torch::ConstantIntOp::create( + rewriter, binder.getLoc(), rewriter.getType(), rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0)); - c = rewriter.create( - binder.getLoc(), cTy, c, mmScale, mmZp); + c = Torch::Aten_MakePerTensorQuantizedTensorOp::create( + rewriter, binder.getLoc(), cTy, c, mmScale, mmZp); cTy = rewriter.getType( resultType.getOptionalSizes(), rewriter.getF32Type()); - c = rewriter.create(binder.getLoc(), cTy, - c); + c = 
Torch::AtenDequantizeSelfOp::create(rewriter, binder.getLoc(), + cTy, c); } cTy = dyn_cast( getQTorchTypeFromTorchIntType(resultType)); - Value dtyVal = rewriter.create( - binder.getLoc(), rewriter.getType(), + Value dtyVal = Torch::ConstantIntOp::create( + rewriter, binder.getLoc(), rewriter.getType(), rewriter.getIntegerAttr( rewriter.getIntegerType(64), static_cast( Torch::getScalarTypeForType(cTy.getDtype())))); - c = rewriter.create( - binder.getLoc(), cTy, c, cScale, cZp, dtyVal); + c = Torch::AtenQuantizePerTensorOp::create(rewriter, binder.getLoc(), + cTy, c, cScale, cZp, dtyVal); rewriter.replaceOpWithNewOp(binder.op, resultType, c); return success(); @@ -849,8 +853,8 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( (axis < -dataRank) || (axis >= dataRank)) return failure(); - Value axisValue = rewriter.create( - binder.getLoc(), rewriter.getI64IntegerAttr(axis)); + Value axisValue = Torch::ConstantIntOp::create( + rewriter, binder.getLoc(), rewriter.getI64IntegerAttr(axis)); rewriter.replaceOpWithNewOp( binder.op, resultTy, data, axisValue, indices, updates); @@ -881,33 +885,33 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( axis += cast(data.getType()).getSizes().size(); - Value constAxis = rewriter.create( - loc, rewriter.getType(), + Value constAxis = Torch::ConstantIntOp::create( + rewriter, loc, rewriter.getType(), rewriter.getIntegerAttr(rewriter.getIntegerType(64), axis)); - Value zero = rewriter.create( - binder.getLoc(), rewriter.getType(), + Value zero = Torch::ConstantIntOp::create( + rewriter, binder.getLoc(), rewriter.getType(), rewriter.getI64IntegerAttr(0)); - Value one = rewriter.create( - binder.getLoc(), rewriter.getType(), + Value one = Torch::ConstantIntOp::create( + rewriter, binder.getLoc(), rewriter.getType(), rewriter.getI64IntegerAttr(1)); - Value axisSize = rewriter.create( - binder.getLoc(), rewriter.getType(), data, + Value axisSize = Torch::AtenSizeIntOp::create( + rewriter, binder.getLoc(), rewriter.getType(), 
data, constAxis); auto indicesTy = cast(indices.getType()); - Value indicesAdd = rewriter.create( - loc, indicesTy, indices, axisSize, one); + Value indicesAdd = Torch::AtenAddScalarOp::create( + rewriter, loc, indicesTy, indices, axisSize, one); - Value inputNeg = rewriter.create( - loc, + Value inputNeg = Torch::AtenLtScalarOp::create( + rewriter, loc, rewriter.getType(indicesTy.getSizes(), rewriter.getI1Type()), indices, zero); - indices = rewriter.create( - loc, indicesTy, inputNeg, indicesAdd, indices); + indices = Torch::AtenWhereSelfOp::create(rewriter, loc, indicesTy, + inputNeg, indicesAdd, indices); if (reduction == "none") { rewriter.replaceOpWithNewOp( @@ -926,9 +930,9 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( } Value cstStrReduction = - rewriter.create(binder.getLoc(), reduction); + Torch::ConstantStrOp::create(rewriter, binder.getLoc(), reduction); Value cstTrue = - rewriter.create(binder.getLoc(), true); + Torch::ConstantBoolOp::create(rewriter, binder.getLoc(), true); rewriter.replaceOpWithNewOp( binder.op, resultType, data, constAxis, indices, updates, cstStrReduction, cstTrue); @@ -959,11 +963,11 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( return failure(); Value cstFalse = - rewriter.create(binder.getLoc(), false); - Value none = rewriter.create(binder.getLoc()); + Torch::ConstantBoolOp::create(rewriter, binder.getLoc(), false); + Value none = Torch::ConstantNoneOp::create(rewriter, binder.getLoc()); - Value len = rewriter.create( - binder.getLoc(), rewriter.getType(), x); + Value len = Torch::AtenLenTOp::create( + rewriter, binder.getLoc(), rewriter.getType(), x); // AtenLenTOp returns a torch.int, so we have to // put that in a tensor. 
@@ -1023,8 +1027,8 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( Value y; if (binder.tensorOperands(x, y) || binder.tensorResultType(resultType)) return failure(); - Value const1 = rewriter.create( - binder.getLoc(), rewriter.getType(), + Value const1 = Torch::ConstantIntOp::create( + rewriter, binder.getLoc(), rewriter.getType(), rewriter.getIntegerAttr(rewriter.getIntegerType(64), 1)); rewriter.replaceOpWithNewOp( binder.op, resultType, x, y, const1); @@ -1046,8 +1050,8 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( if (binder.tensorOperands(valList, numOperands) || binder.tensorResultType(resultType)) return failure(); - Value const1 = rewriter.create( - binder.getLoc(), rewriter.getType(), + Value const1 = Torch::ConstantIntOp::create( + rewriter, binder.getLoc(), rewriter.getType(), rewriter.getIntegerAttr(rewriter.getIntegerType(64), 1)); // Short circuit to binary add if (numOperands == 2) { @@ -1056,12 +1060,14 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( return success(); } // When binder.op->getNumOperands() > 2 - Value curr = rewriter.create( - binder.getLoc(), resultType, valList[0], valList[1], const1); + Value curr = Torch::AtenAddTensorOp::create(rewriter, binder.getLoc(), + resultType, valList[0], + valList[1], const1); for (int i = 2; i < numOperands; i++) { if (i == numOperands - 1) { - curr = rewriter.create( - binder.getLoc(), resultType, curr, valList[i], const1); + curr = Torch::AtenAddTensorOp::create(rewriter, binder.getLoc(), + resultType, curr, valList[i], + const1); } else { SmallVector resultBroadcastShapeInt; SmallVector resultBroadcastShapeValue; @@ -1071,8 +1077,8 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( auto baseType = Torch::ValueTensorType::get( binder.op->getContext(), resultBroadcastShapeInt, resultType.getOptionalDtype()); - curr = rewriter.create( - binder.getLoc(), baseType, curr, valList[i], const1); + curr = Torch::AtenAddTensorOp::create( + rewriter, binder.getLoc(), baseType, 
curr, valList[i], const1); } } rewriter.replaceOp(binder.op, curr); @@ -1156,8 +1162,8 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( } } for (auto i : squeezeDims) { - dimList.push_back(rewriter.create( - binder.getLoc(), rewriter.getI64IntegerAttr(i))); + dimList.push_back(Torch::ConstantIntOp::create( + rewriter, binder.getLoc(), rewriter.getI64IntegerAttr(i))); } } @@ -1168,23 +1174,25 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( SmallVector selectSizes{1}; Type selectResultType = axesType.getWithSizesAndDtype( selectSizes, axesType.getOptionalDtype()); - Value zero = rewriter.create( - binder.getLoc(), rewriter.getType(), + Value zero = Torch::ConstantIntOp::create( + rewriter, binder.getLoc(), rewriter.getType(), rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0)); for (int i = 0; i < rankDiff; i++) { // Go through the axes list and get each dim in the list - Value selectIndex = rewriter.create( - binder.getLoc(), rewriter.getType(), + Value selectIndex = Torch::ConstantIntOp::create( + rewriter, binder.getLoc(), rewriter.getType(), rewriter.getIntegerAttr(rewriter.getIntegerType(64), i)); - Value extract = rewriter.create( - binder.getLoc(), selectResultType, axes, zero, selectIndex); - Value dim = rewriter.create( - binder.getLoc(), rewriter.getType(), extract); + Value extract = Torch::AtenSelectIntOp::create( + rewriter, binder.getLoc(), selectResultType, axes, zero, + selectIndex); + Value dim = Torch::AtenItemOp::create( + rewriter, binder.getLoc(), rewriter.getType(), + extract); dimList.push_back(dim); } } - Value dimValueList = rewriter.create( - binder.getLoc(), + Value dimValueList = Torch::PrimListConstructOp::create( + rewriter, binder.getLoc(), rewriter.getType( rewriter.getType()), dimList); @@ -1257,10 +1265,10 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( unsqueezeShape.insert(unsqueezeShape.begin() + dim, 1); Type unsqueezeType = resultType.getWithSizesAndDtype( unsqueezeShape, 
resultType.getOptionalDtype()); - Value cstDim = rewriter.create( - loc, rewriter.getI64IntegerAttr(dim)); - result = rewriter.create(loc, unsqueezeType, - result, cstDim); + Value cstDim = Torch::ConstantIntOp::create( + rewriter, loc, rewriter.getI64IntegerAttr(dim)); + result = Torch::AtenUnsqueezeOp::create(rewriter, loc, unsqueezeType, + result, cstDim); } rewriter.replaceOp(binder.op, result); return success(); @@ -1280,11 +1288,12 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( axis += cast(input.getType()).getSizes().size(); - Value constAxis = rewriter.create( - binder.getLoc(), rewriter.getType(), + Value constAxis = Torch::ConstantIntOp::create( + rewriter, binder.getLoc(), rewriter.getType(), rewriter.getIntegerAttr(rewriter.getIntegerType(64), axis)); - Value noneVal = rewriter.create(binder.getLoc()); + Value noneVal = + Torch::ConstantNoneOp::create(rewriter, binder.getLoc()); rewriter.replaceOpWithNewOp( binder.op, resultType, input, constAxis, /*dtype=*/noneVal); @@ -1305,16 +1314,16 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( binder.tensorResultType(resultType)) return failure(); - Value vAlpha = rewriter.create( - binder.getLoc(), rewriter.getType(), + Value vAlpha = Torch::ConstantFloatOp::create( + rewriter, binder.getLoc(), rewriter.getType(), rewriter.getFloatAttr(rewriter.getF64Type(), alpha)); - Value vScale = rewriter.create( - binder.getLoc(), rewriter.getType(), + Value vScale = Torch::ConstantFloatOp::create( + rewriter, binder.getLoc(), rewriter.getType(), rewriter.getFloatAttr(rewriter.getF64Type(), gamma)); - Value vInputScale = rewriter.create( - binder.getLoc(), rewriter.getType(), + Value vInputScale = Torch::ConstantFloatOp::create( + rewriter, binder.getLoc(), rewriter.getType(), rewriter.getFloatAttr(rewriter.getF64Type(), 1.0)); rewriter.replaceOpWithNewOp( @@ -1333,8 +1342,8 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( 0)) return failure(); - Value data = rewriter.create( - binder.getLoc(), 
operand.getType(), operand); + Value data = Torch::AtenAbsOp::create(rewriter, binder.getLoc(), + operand.getType(), operand); return reduceOpImpl( binder, rewriter, data, resultType, @@ -1354,8 +1363,8 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( // A ReduceL2 op is equivalent to the following sequence of operations: // Mul(x, x) -> ReduceSum -> CastF32 -> Sqrt -> CastLike(resultType) - Value squareOfOperand = rewriter.create( - binder.getLoc(), operand.getType(), operand, operand); + Value squareOfOperand = Torch::AtenMulTensorOp::create( + rewriter, binder.getLoc(), operand.getType(), operand, operand); auto reducedSum = reduceOpImpl( binder, rewriter, squareOfOperand, resultType, operand, keepDims, @@ -1365,12 +1374,14 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( binder.op, "Failed to perform sum operation on square of operand"); - Value castDType = rewriter.create( - binder.getLoc(), rewriter.getI64IntegerAttr(/*Float32Type*/ 6)); + Value castDType = Torch::ConstantIntOp::create( + rewriter, binder.getLoc(), + rewriter.getI64IntegerAttr(/*Float32Type*/ 6)); - Value noneVal = rewriter.create(binder.getLoc()); + Value noneVal = + Torch::ConstantNoneOp::create(rewriter, binder.getLoc()); Value constFalse = - rewriter.create(binder.getLoc(), false); + Torch::ConstantBoolOp::create(rewriter, binder.getLoc(), false); // Perform an AtenToDtype op on the squared sum of the operand, stored // now in operand itself. 
@@ -1378,13 +1389,13 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( .getOptionalSizes(); auto f32ResultType = rewriter.getType( size, rewriter.getF32Type()); - Value operandCast = rewriter.create( - binder.getLoc(), f32ResultType, operand, castDType, + Value operandCast = Torch::AtenToDtypeOp::create( + rewriter, binder.getLoc(), f32ResultType, operand, castDType, /*non_blocking=*/constFalse, /*copy=*/constFalse, /*memory_format=*/noneVal); - Value operandSqrt = rewriter.create( - binder.getLoc(), f32ResultType, operandCast); + Value operandSqrt = Torch::AtenSqrtOp::create( + rewriter, binder.getLoc(), f32ResultType, operandCast); Value resultDtype = Torch::getDtypeIntValueForType( rewriter, binder.getLoc(), resultType.getDtype()); @@ -1434,21 +1445,23 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( return failure(); // out = Log(reducesum(exp(data))) - Value castDType = rewriter.create( - binder.getLoc(), rewriter.getI64IntegerAttr(/*Float64Type*/ 7)); - Value noneVal = rewriter.create(binder.getLoc()); + Value castDType = Torch::ConstantIntOp::create( + rewriter, binder.getLoc(), + rewriter.getI64IntegerAttr(/*Float64Type*/ 7)); + Value noneVal = + Torch::ConstantNoneOp::create(rewriter, binder.getLoc()); Value constFalse = - rewriter.create(binder.getLoc(), false); + Torch::ConstantBoolOp::create(rewriter, binder.getLoc(), false); auto size = dyn_cast(data.getType()).getOptionalSizes(); auto f64ResultType = rewriter.getType( size, rewriter.getF64Type()); - Value dataCast = rewriter.create( - binder.getLoc(), f64ResultType, data, castDType, + Value dataCast = Torch::AtenToDtypeOp::create( + rewriter, binder.getLoc(), f64ResultType, data, castDType, /*non_blocking=*/constFalse, /*copy=*/constFalse, /*memory_format=*/noneVal); - Value dataExp = rewriter.create( - binder.getLoc(), f64ResultType, dataCast); + Value dataExp = Torch::AtenExpOp::create(rewriter, binder.getLoc(), + f64ResultType, dataCast); auto f64ReduceType = rewriter.getType( 
resultType.getOptionalSizes(), rewriter.getF64Type()); auto reducedSumBool = reduceOpImpl( @@ -1458,8 +1471,8 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( return rewriter.notifyMatchFailure( binder.op, "Failed to perform sum operation on square of operand"); - Value finalResult = rewriter.create( - binder.getLoc(), f64ReduceType, data); + Value finalResult = Torch::AtenLogOp::create(rewriter, binder.getLoc(), + f64ReduceType, data); Value resultDtype = Torch::getDtypeIntValueForType( rewriter, binder.getLoc(), resultType.getDtype()); rewriter.replaceOpWithNewOp( @@ -1496,8 +1509,8 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( "noop_with_empty_axes", 0)) return failure(); - Value dataSquare = rewriter.create( - binder.getLoc(), data.getType(), data, data); + Value dataSquare = Torch::AtenMulTensorOp::create( + rewriter, binder.getLoc(), data.getType(), data, data); return reduceOpImpl( binder, rewriter, dataSquare, resultType, @@ -1550,8 +1563,8 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( if (FloatType fpTy = dyn_cast(dty)) { auto inf = APFloat::getInf(fpTy.getFloatSemantics(), /*Negative=*/true); - scalar = rewriter.create( - binder.getLoc(), rewriter.getType(), + scalar = Torch::ConstantFloatOp::create( + rewriter, binder.getLoc(), rewriter.getType(), rewriter.getFloatAttr(rewriter.getF64Type(), inf.convertToDouble())); } @@ -1561,8 +1574,8 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( intTy.isSigned() ? 
APInt::getSignedMinValue(intTy.getIntOrFloatBitWidth()) : APInt::getMinValue(intTy.getIntOrFloatBitWidth()); - scalar = rewriter.create( - binder.getLoc(), torchIntTy, + scalar = Torch::ConstantIntOp::create( + rewriter, binder.getLoc(), torchIntTy, rewriter.getIntegerAttr(rewriter.getIntegerType(64), minInt.getSExtValue())); } @@ -1571,21 +1584,23 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( for (int i = 0, s = resultType.getSizes().size(); i < s; ++i) { auto staticDim = resultType.getSizes()[i]; if (staticDim != Torch::kUnknownSize) { - fillDims.push_back(rewriter.create( - binder.getLoc(), torchIntTy, + fillDims.push_back(Torch::ConstantIntOp::create( + rewriter, binder.getLoc(), torchIntTy, rewriter.getI64IntegerAttr(staticDim))); continue; } - Value iv = rewriter.create( - binder.getLoc(), torchIntTy, rewriter.getI64IntegerAttr(i)); - fillDims.push_back(rewriter.create( - binder.getLoc(), torchIntTy, data, iv)); + Value iv = Torch::ConstantIntOp::create( + rewriter, binder.getLoc(), torchIntTy, + rewriter.getI64IntegerAttr(i)); + fillDims.push_back(Torch::AtenSizeIntOp::create( + rewriter, binder.getLoc(), torchIntTy, data, iv)); } - Value none = rewriter.create(binder.getLoc()); - Value fillDimsList = rewriter.create( - binder.getLoc(), Torch::ListType::get(torchIntTy), fillDims); + Value none = Torch::ConstantNoneOp::create(rewriter, binder.getLoc()); + Value fillDimsList = Torch::PrimListConstructOp::create( + rewriter, binder.getLoc(), Torch::ListType::get(torchIntTy), + fillDims); rewriter.replaceOpWithNewOp( binder.op, resultType, fillDimsList, scalar, none, none, none, none); @@ -1624,8 +1639,8 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( Value scalar; if (FloatType fpTy = dyn_cast(dty)) { auto inf = APFloat::getInf(fpTy.getFloatSemantics()); - scalar = rewriter.create( - binder.getLoc(), rewriter.getType(), + scalar = Torch::ConstantFloatOp::create( + rewriter, binder.getLoc(), rewriter.getType(), 
rewriter.getFloatAttr(rewriter.getF64Type(), inf.convertToDouble())); } @@ -1635,8 +1650,8 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( intTy.isSigned() ? APInt::getSignedMaxValue(intTy.getIntOrFloatBitWidth()) : APInt::getMaxValue(intTy.getIntOrFloatBitWidth()); - scalar = rewriter.create( - binder.getLoc(), torchIntTy, + scalar = Torch::ConstantIntOp::create( + rewriter, binder.getLoc(), torchIntTy, rewriter.getIntegerAttr(rewriter.getIntegerType(64), mx.getSExtValue())); } @@ -1645,21 +1660,23 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( for (int i = 0, s = resultType.getSizes().size(); i < s; ++i) { auto staticDim = resultType.getSizes()[i]; if (staticDim != Torch::kUnknownSize) { - fillDims.push_back(rewriter.create( - binder.getLoc(), torchIntTy, + fillDims.push_back(Torch::ConstantIntOp::create( + rewriter, binder.getLoc(), torchIntTy, rewriter.getI64IntegerAttr(staticDim))); continue; } - Value iv = rewriter.create( - binder.getLoc(), torchIntTy, rewriter.getI64IntegerAttr(i)); - fillDims.push_back(rewriter.create( - binder.getLoc(), torchIntTy, data, iv)); + Value iv = Torch::ConstantIntOp::create( + rewriter, binder.getLoc(), torchIntTy, + rewriter.getI64IntegerAttr(i)); + fillDims.push_back(Torch::AtenSizeIntOp::create( + rewriter, binder.getLoc(), torchIntTy, data, iv)); } - Value none = rewriter.create(binder.getLoc()); - Value fillDimsList = rewriter.create( - binder.getLoc(), Torch::ListType::get(torchIntTy), fillDims); + Value none = Torch::ConstantNoneOp::create(rewriter, binder.getLoc()); + Value fillDimsList = Torch::PrimListConstructOp::create( + rewriter, binder.getLoc(), Torch::ListType::get(torchIntTy), + fillDims); rewriter.replaceOpWithNewOp( binder.op, resultType, fillDimsList, scalar, none, none, none, none); @@ -1693,8 +1710,8 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( auto shapeType = Torch::ValueTensorType::get( binder.op->getContext(), SmallVector{inputRank}, resultType.getOptionalDtype()); - Value 
shape = rewriter.create( - binder.getLoc(), shapeType, operand); + Value shape = Torch::Aten_ShapeAsTensorOp::create( + rewriter, binder.getLoc(), shapeType, operand); if (inputRank == 0) { rewriter.replaceOpWithNewOp( @@ -1707,12 +1724,12 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( return success(); } - Value sv = rewriter.create( - binder.getLoc(), rewriter.getI64IntegerAttr(start)); - Value ev = rewriter.create( - binder.getLoc(), rewriter.getI64IntegerAttr(end)); - Value step = rewriter.create(binder.getLoc(), 1); - Value dim = rewriter.create(binder.getLoc(), 0); + Value sv = Torch::ConstantIntOp::create( + rewriter, binder.getLoc(), rewriter.getI64IntegerAttr(start)); + Value ev = Torch::ConstantIntOp::create( + rewriter, binder.getLoc(), rewriter.getI64IntegerAttr(end)); + Value step = Torch::ConstantIntOp::create(rewriter, binder.getLoc(), 1); + Value dim = Torch::ConstantIntOp::create(rewriter, binder.getLoc(), 0); rewriter.replaceOpWithNewOp( binder.op, resultType, shape, dim, sv, ev, step); @@ -1772,28 +1789,28 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( if (dim < 0) dim += selfTy.getSizes().size(); - Value dimValue = rewriter.create( - loc, rewriter.getType(), + Value dimValue = Torch::ConstantIntOp::create( + rewriter, loc, rewriter.getType(), rewriter.getI64IntegerAttr(dim)); - Value vNumOutputs = rewriter.create( - loc, rewriter.getType(), + Value vNumOutputs = Torch::ConstantIntOp::create( + rewriter, loc, rewriter.getType(), rewriter.getI64IntegerAttr(numOutputs)); - Value one = rewriter.create( - loc, rewriter.getI64IntegerAttr(1)); - Value zero = rewriter.create( - loc, rewriter.getI64IntegerAttr(0)); + Value one = Torch::ConstantIntOp::create(rewriter, loc, + rewriter.getI64IntegerAttr(1)); + Value zero = Torch::ConstantIntOp::create( + rewriter, loc, rewriter.getI64IntegerAttr(0)); - Value vDimSize = rewriter.create( - loc, rewriter.getType(), self, dimValue); + Value vDimSize = Torch::AtenSizeIntOp::create( + rewriter, 
loc, rewriter.getType(), self, dimValue); Value addNumOutputs = - rewriter.create(loc, vDimSize, vNumOutputs); + Torch::AtenAddIntOp::create(rewriter, loc, vDimSize, vNumOutputs); Value subOne = - rewriter.create(loc, addNumOutputs, one); - Value splitSize = - rewriter.create(loc, subOne, vNumOutputs); + Torch::AtenSubIntOp::create(rewriter, loc, addNumOutputs, one); + Value splitSize = Torch::AtenFloordivIntOp::create(rewriter, loc, + subOne, vNumOutputs); llvm::SmallVector outputs; Value step = one; @@ -1801,16 +1818,16 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( for (int i = 0; i < numOutputs - 1; ++i) { Value end = - rewriter.create(loc, start, splitSize); - Value slice = rewriter.create( - loc, result0Ty, self, dimValue, start, end, step); + Torch::AtenAddIntOp::create(rewriter, loc, start, splitSize); + Value slice = Torch::AtenSliceTensorOp::create( + rewriter, loc, result0Ty, self, dimValue, start, end, step); start = end; outputs.push_back(slice); } Value end = vDimSize; - Value lastSlice = rewriter.create( - loc, resultNTy, self, dimValue, start, end, step); + Value lastSlice = Torch::AtenSliceTensorOp::create( + rewriter, loc, resultNTy, self, dimValue, start, end, step); outputs.push_back(lastSlice); rewriter.replaceOp(binder.op, outputs); @@ -1863,12 +1880,12 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( intermediateShape[dim] = d == intermediateShape[dim] ? 
d : -1; } - Torch::PrimTolistOp splitToList = rewriter.create( - binder.getLoc(), + Torch::PrimTolistOp splitToList = Torch::PrimTolistOp::create( + rewriter, binder.getLoc(), Torch::ListType::get(rewriter.getType()), split); - Value dimValue = rewriter.create( - binder.getLoc(), rewriter.getType(), + Value dimValue = Torch::ConstantIntOp::create( + rewriter, binder.getLoc(), rewriter.getType(), rewriter.getIntegerAttr(rewriter.getIntegerType(64), dim)); // TODO: Attempting to use the shape expected by the ONNX mlir as ground @@ -1878,8 +1895,8 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( /*std::optional>=*/intermediateShape, result0Ty.getOptionalDtype())); Torch::AtenSplitWithSizesOp new_op = - rewriter.create( - binder.getLoc(), resultOuterType, self, + Torch::AtenSplitWithSizesOp::create( + rewriter, binder.getLoc(), resultOuterType, self, splitToList.getResult(0), dimValue); // the onnx op is variadic with multiple results, but AtenSplitWithSizes @@ -1966,16 +1983,16 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( std::swap(shape[i], shape[target]); std::swap(current[i], current[target]); - Value dim0 = rewriter.create( - binder.getLoc(), rewriter.getType(), + Value dim0 = Torch::ConstantIntOp::create( + rewriter, binder.getLoc(), rewriter.getType(), rewriter.getIntegerAttr(rewriter.getIntegerType(64), i)); - Value dim1 = rewriter.create( - binder.getLoc(), rewriter.getType(), + Value dim1 = Torch::ConstantIntOp::create( + rewriter, binder.getLoc(), rewriter.getType(), rewriter.getIntegerAttr(rewriter.getIntegerType(64), target)); - operand = rewriter.create( - loc, + operand = Torch::AtenTransposeIntOp::create( + rewriter, loc, Torch::ValueTensorType::get(tensorType.getContext(), shape, operandType.getOptionalDtype()), operand, dim0, dim1); @@ -2040,17 +2057,18 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( } else { // The default `steps` value is a 1d tensor filled with ones with a // size equal to the size of `starts` and `ends`. 
- Value none = rewriter.create(loc); - Value sizeStepInput = rewriter.create( - loc, rewriter.getType(), + Value none = Torch::ConstantNoneOp::create(rewriter, loc); + Value sizeStepInput = Torch::ConstantIntOp::create( + rewriter, loc, rewriter.getType(), rewriter.getIntegerAttr(rewriter.getIntegerType(64), startSize)); - Value sizeStepsInput = rewriter.create( - loc, + Value sizeStepsInput = Torch::PrimListConstructOp::create( + rewriter, loc, Torch::ListType::get( Torch::IntType::get(binder.op->getContext())), sizeStepInput); - steps = rewriter.create( - loc, startsTorchTy, sizeStepsInput, none, none, none, none); + steps = + Torch::AtenOnesOp::create(rewriter, loc, startsTorchTy, + sizeStepsInput, none, none, none, none); } if (!(endsTy.getRank() == 1 && startsTy.getRank() == 1 && @@ -2077,19 +2095,19 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( return rewriter.notifyMatchFailure( binder.op, "Steps should be the same size of starts and ends"); - Value zero = rewriter.create( - loc, rewriter.getType(), + Value zero = Torch::ConstantIntOp::create( + rewriter, loc, rewriter.getType(), rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0)); auto select = [&](Value v, Value k) -> Value { auto ty = cast(v.getType()); - auto sel = rewriter.create( - loc, + auto sel = Torch::AtenIndexSelectOp::create( + rewriter, loc, Torch::ValueTensorType::get(ty.getContext(), ArrayRef{1}, ty.getOptionalDtype()), v, zero, k); - Value item = rewriter.create( - loc, rewriter.getType(), sel); + Value item = Torch::AtenItemOp::create( + rewriter, loc, rewriter.getType(), sel); return item; }; @@ -2104,11 +2122,11 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( context, intermediateShape, resultTorchType.getOptionalDtype()); for (int i = 0; i < endSize; ++i) { - Value k = rewriter.create( - loc, rewriter.getType(), + Value k = Torch::ConstantIntOp::create( + rewriter, loc, rewriter.getType(), rewriter.getIntegerAttr(rewriter.getIntegerType(64), i)); - Value kTensor = 
rewriter.create( - loc, + Value kTensor = Torch::PrimNumToTensorScalarOp::create( + rewriter, loc, Torch::ValueTensorType::get( context, ArrayRef{1}, rewriter.getIntegerType(64, /*signed*/ 1)), @@ -2121,8 +2139,8 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( auto sliceType = intermediateType; sliceType = i == (endSize - 1) ? resultTorchType : sliceType; - operand = rewriter.create( - loc, sliceType, operand, axis, start, end, step); + operand = Torch::AtenSliceTensorOp::create( + rewriter, loc, sliceType, operand, axis, start, end, step); } rewriter.replaceOp(binder.op, operand); @@ -2147,11 +2165,11 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( if (hasStaticShape) { SmallVector resultShape; for (int64_t dim : resultShapeInt) { - resultShape.push_back(rewriter.create( - binder.getLoc(), rewriter.getI64IntegerAttr(dim))); + resultShape.push_back(Torch::ConstantIntOp::create( + rewriter, binder.getLoc(), rewriter.getI64IntegerAttr(dim))); } - Value resultShapeList = rewriter.create( - binder.getLoc(), + Value resultShapeList = Torch::PrimListConstructOp::create( + rewriter, binder.getLoc(), Torch::ListType::get( Torch::IntType::get(binder.op->getContext())), resultShape); @@ -2172,27 +2190,29 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( dyn_cast(shape.getType()).getSizes(); auto dataSizes = dyn_cast(data.getType()).getSizes(); - Value zero = rewriter.create( - binder.getLoc(), rewriter.getType(), + Value zero = Torch::ConstantIntOp::create( + rewriter, binder.getLoc(), rewriter.getType(), rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0)); if (allowzero == 0) { // convert shape (tensor) into torch int list while dealing with zero // vals for (int i = 0; i < shapeSizes[0]; i++) { // Go through the shape list and get each dim in the list - Value selectIndex = rewriter.create( - binder.getLoc(), rewriter.getType(), + Value selectIndex = Torch::ConstantIntOp::create( + rewriter, binder.getLoc(), rewriter.getType(), 
rewriter.getIntegerAttr(rewriter.getIntegerType(64), i)); - Value extract = rewriter.create( - binder.getLoc(), selectResultType, shape, zero, selectIndex); - Value dim = rewriter.create( - binder.getLoc(), rewriter.getType(), extract); + Value extract = Torch::AtenSelectIntOp::create( + rewriter, binder.getLoc(), selectResultType, shape, zero, + selectIndex); + Value dim = Torch::AtenItemOp::create( + rewriter, binder.getLoc(), rewriter.getType(), + extract); // deal with zero axis values: replace with original dim value in // input - Value isZero = - rewriter.create(binder.getLoc(), dim, zero); + Value isZero = Torch::AtenEqIntOp::create(rewriter, binder.getLoc(), + dim, zero); isZero = - rewriter.create(binder.getLoc(), isZero); + Torch::AtenIntBoolOp::create(rewriter, binder.getLoc(), isZero); int64_t dataRank = dataSizes.size(); if (i < dataRank) { @@ -2202,26 +2222,26 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( ArrayRef(), int64Ty); auto boolTy = rewriter.getType( ArrayRef(), rewriter.getI1Type()); - Value iv = rewriter.create( - binder.getLoc(), rewriter.getI64IntegerAttr(i)); - Value inDim = rewriter.create( - binder.getLoc(), torchIntTy, data, iv); - isZero = rewriter.create( - binder.getLoc(), boolTy, isZero); - inDim = rewriter.create( - binder.getLoc(), dimTy, inDim); - dim = rewriter.create( - binder.getLoc(), dimTy, dim); - Value finalDim = rewriter.create( - binder.getLoc(), dimTy, isZero, inDim, dim); - dim = rewriter.create( - binder.getLoc(), rewriter.getType(), + Value iv = Torch::ConstantIntOp::create( + rewriter, binder.getLoc(), rewriter.getI64IntegerAttr(i)); + Value inDim = Torch::AtenSizeIntOp::create( + rewriter, binder.getLoc(), torchIntTy, data, iv); + isZero = Torch::PrimNumToTensorScalarOp::create( + rewriter, binder.getLoc(), boolTy, isZero); + inDim = Torch::PrimNumToTensorScalarOp::create( + rewriter, binder.getLoc(), dimTy, inDim); + dim = Torch::PrimNumToTensorScalarOp::create( + rewriter, binder.getLoc(), dimTy, dim); 
+ Value finalDim = Torch::AtenWhereSelfOp::create( + rewriter, binder.getLoc(), dimTy, isZero, inDim, dim); + dim = Torch::AtenItemOp::create( + rewriter, binder.getLoc(), rewriter.getType(), finalDim); } dimList.push_back(dim); } - Value dimValueList = rewriter.create( - binder.getLoc(), + Value dimValueList = Torch::PrimListConstructOp::create( + rewriter, binder.getLoc(), Torch::ListType::get( Torch::IntType::get(binder.op->getContext())), dimList); @@ -2232,17 +2252,19 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( // convert axes (tensor) into torch int list for (int i = 0; i < shapeSizes[0]; i++) { // Go through the axes list and get each dim in the list - Value selectIndex = rewriter.create( - binder.getLoc(), rewriter.getType(), + Value selectIndex = Torch::ConstantIntOp::create( + rewriter, binder.getLoc(), rewriter.getType(), rewriter.getIntegerAttr(rewriter.getIntegerType(64), i)); - Value extract = rewriter.create( - binder.getLoc(), selectResultType, shape, zero, selectIndex); - Value dim = rewriter.create( - binder.getLoc(), rewriter.getType(), extract); + Value extract = Torch::AtenSelectIntOp::create( + rewriter, binder.getLoc(), selectResultType, shape, zero, + selectIndex); + Value dim = Torch::AtenItemOp::create( + rewriter, binder.getLoc(), rewriter.getType(), + extract); dimList.push_back(dim); } - Value dimValueList = rewriter.create( - binder.getLoc(), + Value dimValueList = Torch::PrimListConstructOp::create( + rewriter, binder.getLoc(), Torch::ListType::get(Torch::IntType::get(binder.op->getContext())), dimList); rewriter.replaceOpWithNewOp(binder.op, resultType, @@ -2277,15 +2299,15 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( int64_t rank = dataTy.getSizes().size(); SmallVector axesList; - Value zero = rewriter.create( - binder.getLoc(), rewriter.getI64IntegerAttr(0)); + Value zero = Torch::ConstantIntOp::create( + rewriter, binder.getLoc(), rewriter.getI64IntegerAttr(0)); // Previous version of the operation had the 
axes as an attribute: llvm::SmallVector axesAttr; if (!binder.s64IntegerArrayAttr(axesAttr, "axes", {})) { for (int i = 0, s = axesAttr.size(); i < s; ++i) { - axesList.push_back(rewriter.create( - binder.getLoc(), torchIntTy, + axesList.push_back(Torch::ConstantIntOp::create( + rewriter, binder.getLoc(), torchIntTy, rewriter.getI64IntegerAttr(axesAttr[i]))); } } @@ -2300,15 +2322,15 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( cast(axes.getType()); auto sizes = axesType.getSizes(); for (int i = 0; i < sizes[0]; i++) { - Value selectIndex = rewriter.create( - binder.getLoc(), rewriter.getI64IntegerAttr(i)); - Value extract = rewriter.create( - binder.getLoc(), + Value selectIndex = Torch::ConstantIntOp::create( + rewriter, binder.getLoc(), rewriter.getI64IntegerAttr(i)); + Value extract = Torch::AtenSelectIntOp::create( + rewriter, binder.getLoc(), axesType.getWithSizesAndDtype(llvm::SmallVector{1}, axesType.getOptionalDtype()), axes, zero, selectIndex); - Value dim = rewriter.create(binder.getLoc(), - torchIntTy, extract); + Value dim = Torch::AtenItemOp::create(rewriter, binder.getLoc(), + torchIntTy, extract); axesList.push_back(dim); } } @@ -2326,39 +2348,40 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( // Manually set positive axis to all dims. 
if (axesList.empty()) { for (int i = 0; i < rank; i++) { - Value dimValue = rewriter.create( - binder.getLoc(), rewriter.getI64IntegerAttr(i)); + Value dimValue = Torch::ConstantIntOp::create( + rewriter, binder.getLoc(), rewriter.getI64IntegerAttr(i)); axesList.push_back(dimValue); } } // Handle negative axis: - Value rankVal = rewriter.create(binder.getLoc(), - torchIntTy, data); + Value rankVal = Torch::AtenDimOp::create(rewriter, binder.getLoc(), + torchIntTy, data); for (Value &axes : axesList) { Value isNegative = - rewriter.create(binder.getLoc(), axes, zero); - isNegative = rewriter.create(binder.getLoc(), - isNegative); - Value finalOffset = rewriter.create( - binder.getLoc(), isNegative, rankVal); - axes = rewriter.create(binder.getLoc(), axes, - finalOffset); + Torch::AtenLtIntOp::create(rewriter, binder.getLoc(), axes, zero); + isNegative = Torch::AtenIntBoolOp::create(rewriter, binder.getLoc(), + isNegative); + Value finalOffset = Torch::AtenMulIntOp::create( + rewriter, binder.getLoc(), isNegative, rankVal); + axes = Torch::AtenAddIntOp::create(rewriter, binder.getLoc(), axes, + finalOffset); } // Handle multiple axes case: // ReduceProd on each dim, always set keepDimsBool == True to avoid // segfault. 
Value trueVal = - rewriter.create(binder.getLoc(), true); - Value noneVal = rewriter.create(binder.getLoc()); + Torch::ConstantBoolOp::create(rewriter, binder.getLoc(), true); + Value noneVal = + Torch::ConstantNoneOp::create(rewriter, binder.getLoc()); SmallVector intermediateShape(rank, Torch::kUnknownSize); Value dataReduceProd = data; for (int i = 0, numAxes = axesList.size(); i < numAxes; i++) { auto axis = axesList[i]; if (keepDims && i == numAxes - 1) { - dataReduceProd = rewriter.create( - binder.getLoc(), + dataReduceProd = Torch::AtenProdDimIntOp::create( + rewriter, binder.getLoc(), dataTy.getWithSizesAndDtype(resultType.getSizes(), dataTy.getOptionalDtype()), dataReduceProd, axis, trueVal, noneVal); @@ -2367,9 +2390,9 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( } Type resultTyReduceProd = dataTy.getWithSizesAndDtype( ArrayRef(intermediateShape), dataTy.getOptionalDtype()); - dataReduceProd = rewriter.create( - binder.getLoc(), resultTyReduceProd, dataReduceProd, axis, - trueVal, noneVal); + dataReduceProd = Torch::AtenProdDimIntOp::create( + rewriter, binder.getLoc(), resultTyReduceProd, dataReduceProd, + axis, trueVal, noneVal); } // Derived the final shape of the tensor after prod loop of each axis. @@ -2394,14 +2417,13 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( // Reshape the prod loop result to the final result shape. 
SmallVector dataReduceProdShape; for (auto dim : dataReduceProdSize) - dataReduceProdShape.push_back(rewriter.create( - binder.getLoc(), rewriter.getI64IntegerAttr(dim))); - Value dataReduceProdShapeList = - rewriter.create( - binder.getLoc(), - rewriter.getType( - rewriter.getType()), - dataReduceProdShape); + dataReduceProdShape.push_back(Torch::ConstantIntOp::create( + rewriter, binder.getLoc(), rewriter.getI64IntegerAttr(dim))); + Value dataReduceProdShapeList = Torch::PrimListConstructOp::create( + rewriter, binder.getLoc(), + rewriter.getType( + rewriter.getType()), + dataReduceProdShape); rewriter.replaceOpWithNewOp( binder.op, resultType, dataReduceProd, dataReduceProdShapeList); return success(); @@ -2413,7 +2435,7 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( Torch::ValueTensorType resultType; Value start, limit, delta; auto loc = binder.getLoc(); - Value none = rewriter.create(loc); + Value none = Torch::ConstantNoneOp::create(rewriter, loc); if (binder.tensorOperandAtIndex(start, 0) || binder.tensorOperandAtIndex(limit, 1) || binder.tensorOperandAtIndex(delta, 2) || @@ -2470,19 +2492,19 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( llvm::SmallVector dims; int64_t rank = operandTy.getSizes().size(); for (int i = 0; i < rank; ++i) { - auto iv = rewriter.create( - loc, rewriter.getI64IntegerAttr(i)); - Value dim = rewriter.create( - loc, rewriter.getType(), operand, iv); + auto iv = Torch::ConstantIntOp::create(rewriter, loc, + rewriter.getI64IntegerAttr(i)); + Value dim = Torch::AtenSizeIntOp::create( + rewriter, loc, rewriter.getType(), operand, iv); dims.push_back(dim); } - Value cstFalse = rewriter.create(loc, false); - Value none = rewriter.create(loc); + Value cstFalse = Torch::ConstantBoolOp::create(rewriter, loc, false); + Value none = Torch::ConstantNoneOp::create(rewriter, loc); if (dims.empty()) { - Value one = rewriter.create( - loc, rewriter.getI64IntegerAttr(1)); + Value one = Torch::ConstantIntOp::create( + rewriter, loc, 
rewriter.getI64IntegerAttr(1)); rewriter.replaceOpWithNewOp( op, resultType, one, none, none, cstFalse); return success(); @@ -2490,7 +2512,7 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( Value prod = dims[0]; for (int i = 1, s = dims.size(); i < s; ++i) - prod = rewriter.create(loc, prod, dims[i]); + prod = Torch::AtenMulIntOp::create(rewriter, loc, prod, dims[i]); rewriter.replaceOpWithNewOp( op, resultType, prod, none, none, cstFalse); @@ -2515,21 +2537,23 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( cast(repeatDims.getType()); Type selectResultType = shapeType.getWithSizesAndDtype( llvm::ArrayRef(selectSizes), shapeType.getOptionalDtype()); - Value zero = rewriter.create( - binder.getLoc(), rewriter.getType(), + Value zero = Torch::ConstantIntOp::create( + rewriter, binder.getLoc(), rewriter.getType(), rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0)); for (int i = 0; i < repeatDimsSizes[0]; i++) { - Value selectIndex = rewriter.create( - binder.getLoc(), rewriter.getType(), + Value selectIndex = Torch::ConstantIntOp::create( + rewriter, binder.getLoc(), rewriter.getType(), rewriter.getIntegerAttr(rewriter.getIntegerType(64), i)); - Value extract = rewriter.create( - binder.getLoc(), selectResultType, repeatDims, zero, selectIndex); - Value dim = rewriter.create( - binder.getLoc(), rewriter.getType(), extract); + Value extract = Torch::AtenSelectIntOp::create( + rewriter, binder.getLoc(), selectResultType, repeatDims, zero, + selectIndex); + Value dim = Torch::AtenItemOp::create( + rewriter, binder.getLoc(), rewriter.getType(), + extract); dimList.push_back(dim); } - Value dimValueList = rewriter.create( - binder.getLoc(), + Value dimValueList = Torch::PrimListConstructOp::create( + rewriter, binder.getLoc(), Torch::ListType::get(Torch::IntType::get(binder.op->getContext())), dimList); @@ -2557,14 +2581,15 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( "Unimplemented: unranked tensor"); unsigned rank = *maybeRank; axis = 
Torch::toPositiveDim(axis, rank); - Value cstAxis = rewriter.create( - binder.getLoc(), rewriter.getI64IntegerAttr(axis)); + Value cstAxis = Torch::ConstantIntOp::create( + rewriter, binder.getLoc(), rewriter.getI64IntegerAttr(axis)); Value cstLargest = - rewriter.create(binder.getLoc(), largest); + Torch::ConstantBoolOp::create(rewriter, binder.getLoc(), largest); Value cstSorted = - rewriter.create(binder.getLoc(), sorted); - Value kValueInt = rewriter.create( - binder.getLoc(), rewriter.getType(), kValue); + Torch::ConstantBoolOp::create(rewriter, binder.getLoc(), sorted); + Value kValueInt = Torch::AtenItemOp::create( + rewriter, binder.getLoc(), rewriter.getType(), + kValue); rewriter.replaceOpWithNewOp( binder.op, Values_type, Indices_type, input, kValueInt, cstAxis, cstLargest, cstSorted); @@ -2591,34 +2616,34 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( return failure(); } // out = ln(exp(x) + 1) - Value exp = rewriter.create(binder.getLoc(), - resultType, input); + Value exp = Torch::AtenExpOp::create(rewriter, binder.getLoc(), + resultType, input); rewriter.replaceOpWithNewOp(binder.op, resultType, exp); return success(); }); - patterns.onOp("Softsign", 22, - [](OpBinder binder, ConversionPatternRewriter &rewriter) { - Torch::ValueTensorType resultType; - Value input; - if (binder.tensorOperand(input) || - binder.tensorResultType(resultType)) { - return failure(); - } + patterns.onOp( + "Softsign", 22, [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + Value input; + if (binder.tensorOperand(input) || + binder.tensorResultType(resultType)) { + return failure(); + } - Value absX = rewriter.create( - binder.getLoc(), resultType, input); + Value absX = Torch::AtenAbsOp::create(rewriter, binder.getLoc(), + resultType, input); - Value constOne = rewriter.create( - binder.getLoc(), rewriter.getI64IntegerAttr(1)); + Value constOne = Torch::ConstantIntOp::create( + rewriter, binder.getLoc(), 
rewriter.getI64IntegerAttr(1)); - Value absXPlusOne = rewriter.create( - binder.getLoc(), resultType, absX, constOne, constOne); + Value absXPlusOne = Torch::AtenAddScalarOp::create( + rewriter, binder.getLoc(), resultType, absX, constOne, constOne); - rewriter.replaceOpWithNewOp( - binder.op, resultType, input, absXPlusOne); - return success(); - }); + rewriter.replaceOpWithNewOp( + binder.op, resultType, input, absXPlusOne); + return success(); + }); patterns.onOp( "Trilu", 14, [](OpBinder binder, ConversionPatternRewriter &rewriter) { Torch::ValueTensorType resultType; @@ -2632,11 +2657,12 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( Value diagonal; if (binder.tensorOperandAtIndex(diagonal, 1)) { - diagonal = rewriter.create( - binder.getLoc(), rewriter.getI64IntegerAttr(0)); + diagonal = Torch::ConstantIntOp::create( + rewriter, binder.getLoc(), rewriter.getI64IntegerAttr(0)); } else { - diagonal = rewriter.create( - binder.getLoc(), rewriter.getType(), diagonal); + diagonal = Torch::AtenItemOp::create( + rewriter, binder.getLoc(), rewriter.getType(), + diagonal); } if (upper) { @@ -2648,26 +2674,27 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( input, diagonal); return success(); }); - patterns.onOp("ThresholdedRelu", 10, - [](OpBinder binder, ConversionPatternRewriter &rewriter) { - Torch::ValueTensorType resultType; - Value input; - float alpha; - if (binder.tensorOperand(input) || - binder.f32FloatAttr(alpha, "alpha", 1.0) || - binder.tensorResultType(resultType)) { - return failure(); - } - Value cstAlpha = rewriter.create( - binder.getLoc(), rewriter.getType(), - rewriter.getFloatAttr(rewriter.getF64Type(), alpha)); - Value value = rewriter.create( - binder.getLoc(), rewriter.getType(), - rewriter.getFloatAttr(rewriter.getF64Type(), 0.0)); - rewriter.replaceOpWithNewOp( - binder.op, resultType, input, cstAlpha, value); - return success(); - }); + patterns.onOp( + "ThresholdedRelu", 10, + [](OpBinder binder, ConversionPatternRewriter 
&rewriter) { + Torch::ValueTensorType resultType; + Value input; + float alpha; + if (binder.tensorOperand(input) || + binder.f32FloatAttr(alpha, "alpha", 1.0) || + binder.tensorResultType(resultType)) { + return failure(); + } + Value cstAlpha = Torch::ConstantFloatOp::create( + rewriter, binder.getLoc(), rewriter.getType(), + rewriter.getFloatAttr(rewriter.getF64Type(), alpha)); + Value value = Torch::ConstantFloatOp::create( + rewriter, binder.getLoc(), rewriter.getType(), + rewriter.getFloatAttr(rewriter.getF64Type(), 0.0)); + rewriter.replaceOpWithNewOp( + binder.op, resultType, input, cstAlpha, value); + return success(); + }); patterns.onOp( "RandomNormal", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) { @@ -2697,24 +2724,26 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( binder.op, "unimplemented support for the given dtype conversion"); } - Value constDtype = rewriter.create( - binder.getLoc(), rewriter.getI64IntegerAttr(dtypeIntTorch.value())); + Value constDtype = Torch::ConstantIntOp::create( + rewriter, binder.getLoc(), + rewriter.getI64IntegerAttr(dtypeIntTorch.value())); Value shapeList = createConstantIntList(binder, rewriter, shape); - Value cstNone = rewriter.create(binder.getLoc()); + Value cstNone = + Torch::ConstantNoneOp::create(rewriter, binder.getLoc()); - Value self = rewriter.create( - binder.op->getLoc(), resultType, shapeList, + Value self = Torch::AtenEmptyMemoryFormatOp::create( + rewriter, binder.op->getLoc(), resultType, shapeList, /*dtype=*/constDtype, /*layout=*/cstNone, /*device=*/cstNone, /*pinMemory=*/cstNone, /*memoryFormat=*/cstNone); - Value cstMean = rewriter.create( - binder.getLoc(), rewriter.getType(), + Value cstMean = Torch::ConstantFloatOp::create( + rewriter, binder.getLoc(), rewriter.getType(), rewriter.getFloatAttr(rewriter.getF64Type(), mean)); - Value cstStd = rewriter.create( - binder.getLoc(), rewriter.getType(), + Value cstStd = Torch::ConstantFloatOp::create( + rewriter, binder.getLoc(), 
rewriter.getType(), rewriter.getFloatAttr(rewriter.getF64Type(), scale)); rewriter.replaceOpWithNewOp( @@ -2752,22 +2781,24 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( binder.op, "unimplemented support for the given dtype conversion"); } - Value constDtype = rewriter.create( - binder.getLoc(), rewriter.getI64IntegerAttr(dtypeIntTorch.value())); + Value constDtype = Torch::ConstantIntOp::create( + rewriter, binder.getLoc(), + rewriter.getI64IntegerAttr(dtypeIntTorch.value())); - Value cstNone = rewriter.create(binder.getLoc()); + Value cstNone = + Torch::ConstantNoneOp::create(rewriter, binder.getLoc()); Value cstFalse = - rewriter.create(binder.getLoc(), false); - input = rewriter.create( - binder.op->getLoc(), resultType, input, constDtype, + Torch::ConstantBoolOp::create(rewriter, binder.getLoc(), false); + input = Torch::AtenToDtypeOp::create( + rewriter, binder.op->getLoc(), resultType, input, constDtype, /*non_blocking=*/cstFalse, /*copy=*/cstFalse, /*memory_format=*/cstNone); - Value cstMean = rewriter.create( - binder.getLoc(), rewriter.getType(), + Value cstMean = Torch::ConstantFloatOp::create( + rewriter, binder.getLoc(), rewriter.getType(), rewriter.getFloatAttr(rewriter.getF64Type(), mean)); - Value cstStd = rewriter.create( - binder.getLoc(), rewriter.getType(), + Value cstStd = Torch::ConstantFloatOp::create( + rewriter, binder.getLoc(), rewriter.getType(), rewriter.getFloatAttr(rewriter.getF64Type(), scale)); rewriter.replaceOpWithNewOp( @@ -2804,24 +2835,26 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( binder.op, "unimplemented support for the given dtype conversion"); } - Value constDtype = rewriter.create( - binder.getLoc(), rewriter.getI64IntegerAttr(dtypeIntTorch.value())); + Value constDtype = Torch::ConstantIntOp::create( + rewriter, binder.getLoc(), + rewriter.getI64IntegerAttr(dtypeIntTorch.value())); Value shapeList = createConstantIntList(binder, rewriter, shape); - Value cstNone = rewriter.create(binder.getLoc()); + 
Value cstNone = + Torch::ConstantNoneOp::create(rewriter, binder.getLoc()); - Value self = rewriter.create( - binder.op->getLoc(), resultType, shapeList, + Value self = Torch::AtenEmptyMemoryFormatOp::create( + rewriter, binder.op->getLoc(), resultType, shapeList, /*dtype=*/constDtype, /*layout=*/cstNone, /*device=*/cstNone, /*pinMemory=*/cstNone, /*memoryFormat=*/cstNone); - Value cstHigh = rewriter.create( - binder.getLoc(), rewriter.getType(), + Value cstHigh = Torch::ConstantFloatOp::create( + rewriter, binder.getLoc(), rewriter.getType(), rewriter.getFloatAttr(rewriter.getF64Type(), high)); - Value cstLow = rewriter.create( - binder.getLoc(), rewriter.getType(), + Value cstLow = Torch::ConstantFloatOp::create( + rewriter, binder.getLoc(), rewriter.getType(), rewriter.getFloatAttr(rewriter.getF64Type(), low)); rewriter.replaceOpWithNewOp( @@ -2859,22 +2892,24 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( binder.op, "unimplemented support for the given dtype conversion"); } - Value constDtype = rewriter.create( - binder.getLoc(), rewriter.getI64IntegerAttr(dtypeIntTorch.value())); + Value constDtype = Torch::ConstantIntOp::create( + rewriter, binder.getLoc(), + rewriter.getI64IntegerAttr(dtypeIntTorch.value())); - Value cstNone = rewriter.create(binder.getLoc()); + Value cstNone = + Torch::ConstantNoneOp::create(rewriter, binder.getLoc()); Value cstFalse = - rewriter.create(binder.getLoc(), false); - input = rewriter.create( - binder.op->getLoc(), resultType, input, constDtype, + Torch::ConstantBoolOp::create(rewriter, binder.getLoc(), false); + input = Torch::AtenToDtypeOp::create( + rewriter, binder.op->getLoc(), resultType, input, constDtype, /*non_blocking=*/cstFalse, /*copy=*/cstFalse, /*memory_format=*/cstNone); - Value cstHigh = rewriter.create( - binder.getLoc(), rewriter.getType(), + Value cstHigh = Torch::ConstantFloatOp::create( + rewriter, binder.getLoc(), rewriter.getType(), rewriter.getFloatAttr(rewriter.getF64Type(), high)); - Value 
cstLow = rewriter.create( - binder.getLoc(), rewriter.getType(), + Value cstLow = Torch::ConstantFloatOp::create( + rewriter, binder.getLoc(), rewriter.getType(), rewriter.getFloatAttr(rewriter.getF64Type(), low)); rewriter.replaceOpWithNewOp( @@ -2899,26 +2934,27 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( } if (binder.tensorOperandAtIndex(weight, 2)) - weight = rewriter.create(binder.getLoc()); + weight = Torch::ConstantNoneOp::create(rewriter, binder.getLoc()); - Value cstIgnoreIndex = rewriter.create( - binder.getLoc(), rewriter.getI64IntegerAttr(ignoreIndex)); + Value cstIgnoreIndex = Torch::ConstantIntOp::create( + rewriter, binder.getLoc(), rewriter.getI64IntegerAttr(ignoreIndex)); int64_t reductionInt = reduction == "none" ? 0 : reduction == "mean" ? 1 : 2; - Value cstReductionInt = rewriter.create( - binder.getLoc(), rewriter.getI64IntegerAttr(reductionInt)); + Value cstReductionInt = Torch::ConstantIntOp::create( + rewriter, binder.getLoc(), + rewriter.getI64IntegerAttr(reductionInt)); // The default PyTorch value for label smoothing is "0.0". 
// Refer: // https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html - Value cstLabelSmoothing = rewriter.create( - binder.getLoc(), rewriter.getType(), + Value cstLabelSmoothing = Torch::ConstantFloatOp::create( + rewriter, binder.getLoc(), rewriter.getType(), rewriter.getFloatAttr(rewriter.getF64Type(), 0.0)); - Value loss = rewriter.create( - binder.getLoc(), resultType, scores, labels, weight, + Value loss = Torch::AtenCrossEntropyLossOp::create( + rewriter, binder.getLoc(), resultType, scores, labels, weight, cstReductionInt, cstIgnoreIndex, cstLabelSmoothing); if (binder.op->getNumResults() == 1) { @@ -2930,11 +2966,13 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( if (binder.tensorResultTypeAtIndex(resultTypeLogProb, 1)) return failure(); - Value dim = rewriter.create( - binder.getLoc(), rewriter.getI64IntegerAttr(1)); - Value cstNone = rewriter.create(binder.getLoc()); - Value logProb = rewriter.create( - binder.getLoc(), resultTypeLogProb, scores, dim, /*dtype=*/cstNone); + Value dim = Torch::ConstantIntOp::create(rewriter, binder.getLoc(), + rewriter.getI64IntegerAttr(1)); + Value cstNone = + Torch::ConstantNoneOp::create(rewriter, binder.getLoc()); + Value logProb = Torch::AtenLogSoftmaxIntOp::create( + rewriter, binder.getLoc(), resultTypeLogProb, scores, dim, + /*dtype=*/cstNone); rewriter.replaceOp(binder.op, {loss, logProb}); return success(); @@ -3045,8 +3083,8 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( auto loc = binder.getLoc(); - Value cstFalse = rewriter.create(loc, false); - Value cstTrue = rewriter.create(loc, true); + Value cstFalse = Torch::ConstantBoolOp::create(rewriter, loc, false); + Value cstTrue = Torch::ConstantBoolOp::create(rewriter, loc, true); Value modeStrValue; Value alignCorners = @@ -3055,7 +3093,7 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( std::string modeStr = "cubic"; if (coordTfMode != "half_pixel") modeStr = modeStr + "_" + coordTfMode; - modeStrValue = 
rewriter.create(loc, modeStr); + modeStrValue = Torch::ConstantStrOp::create(rewriter, loc, modeStr); } auto rankOfInputTensor = sizesOfInputTensor.size(); @@ -3084,7 +3122,7 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( // mode is apparently half_pixel, NOT pytorch_half_pixel if (coordTfMode != "half_pixel" && coordTfMode != "align_corners") modeStr = (modeStr + "_") + coordTfMode; - modeStrValue = rewriter.create(loc, modeStr); + modeStrValue = Torch::ConstantStrOp::create(rewriter, loc, modeStr); } if (mode == "nearest") { std::string modeStr = "nearest"; @@ -3094,7 +3132,7 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( modeStr = (modeStr + "_") + coordTfMode; if (nearest_mode != "floor" && nearest_mode != "") modeStr = modeStr + "," + nearest_mode; - modeStrValue = rewriter.create(loc, modeStr); + modeStrValue = Torch::ConstantStrOp::create(rewriter, loc, modeStr); } auto numberOfOperands = operands.size(); @@ -3106,29 +3144,29 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( Value supportedScaleFactors; Value supportedSizes; - Value noneVal = rewriter.create(loc); + Value noneVal = Torch::ConstantNoneOp::create(rewriter, loc); if (numberOfOperands == 3) { Value proposedScaleFactors = operands[2]; - Value scaleIdentity = rewriter.create( - loc, rewriter.getF64FloatAttr(1.0)); + Value scaleIdentity = Torch::ConstantFloatOp::create( + rewriter, loc, rewriter.getF64FloatAttr(1.0)); // run-time scale factor check for dynamic sizes for (auto &eachDim : nonResizableDims) { Value eachProposedScaleFactor = extractTorchScalar( loc, eachDim, proposedScaleFactors, rewriter); - Value eachScaleFactorIsIdentity = - rewriter.create( - loc, boolType, eachProposedScaleFactor, scaleIdentity); + Value eachScaleFactorIsIdentity = Torch::AtenEqFloatOp::create( + rewriter, loc, boolType, eachProposedScaleFactor, + scaleIdentity); auto errorMessageForEachDim = "Unsupported: non-trivial scale factor for dimension " + std::to_string(eachDim); - 
rewriter.create( - loc, eachScaleFactorIsIdentity, + Torch::RuntimeAssertOp::create( + rewriter, loc, eachScaleFactorIsIdentity, rewriter.getStringAttr(errorMessageForEachDim)); }; @@ -3141,24 +3179,24 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( // run-time target size check for dynamic sizes for (auto &eachDimAsInt : nonResizableDims) { Value eachDimAsValue = - rewriter.create(loc, eachDimAsInt); + Torch::ConstantIntOp::create(rewriter, loc, eachDimAsInt); - Value eachSizeOfInputTensor = rewriter.create( - loc, inputTensor, eachDimAsValue); + Value eachSizeOfInputTensor = Torch::AtenSizeIntOp::create( + rewriter, loc, inputTensor, eachDimAsValue); Value eachProposedSize = extractTorchScalar(loc, eachDimAsInt, proposedSizes, rewriter); - Value eachProposedSizeIsTrivial = - rewriter.create( - loc, boolType, eachProposedSize, eachSizeOfInputTensor); + Value eachProposedSizeIsTrivial = Torch::AtenEqIntOp::create( + rewriter, loc, boolType, eachProposedSize, + eachSizeOfInputTensor); auto errorMessageForEachDim = "Unsupported: non-trivial resizing of dimension " + std::to_string(eachDimAsInt); - rewriter.create( - loc, eachProposedSizeIsTrivial, + Torch::RuntimeAssertOp::create( + rewriter, loc, eachProposedSizeIsTrivial, rewriter.getStringAttr(errorMessageForEachDim)); }; @@ -3213,8 +3251,8 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( auto roisType = dyn_cast(rois.getType()); if (!roisType || !roisType.hasSizes()) return failure(); - Value cstDim = rewriter.create( - binder.getLoc(), rewriter.getI64IntegerAttr(1)); + Value cstDim = Torch::ConstantIntOp::create( + rewriter, binder.getLoc(), rewriter.getI64IntegerAttr(1)); FailureOr unsqueezeIndices = Torch::unsqueezeTensor(rewriter, binder.op, batchIndices, cstDim); if (failed(unsqueezeIndices)) @@ -3224,11 +3262,11 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( cast(batchIndices.getType()); Value dTypeInt = Torch::getDtypeIntValueForType(rewriter, loc, roisType.getDtype()); - Value none 
= rewriter.create(binder.getLoc()); + Value none = Torch::ConstantNoneOp::create(rewriter, binder.getLoc()); Value cstFalse = - rewriter.create(binder.getLoc(), false); - Value newBatchIndices = rewriter.create( - loc, + Torch::ConstantBoolOp::create(rewriter, binder.getLoc(), false); + Value newBatchIndices = Torch::AtenToDtypeOp::create( + rewriter, loc, batchIndicesType.getWithSizesAndDtype( batchIndicesType.getOptionalSizes(), roisType.getOptionalDtype()), @@ -3241,24 +3279,26 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( roisType.getWithSizesAndDtype(/*optionalSizes=*/std::nullopt, /*optionalDtype=*/nullptr); Type listType = Torch::ListType::get(listElemType); - Value tensorList = rewriter.create( - binder.op->getLoc(), listType, ValueRange{newBatchIndices, rois}); - Value newRois = - rewriter.create(loc, catType, tensorList, cstDim); + Value tensorList = Torch::PrimListConstructOp::create( + rewriter, binder.op->getLoc(), listType, + ValueRange{newBatchIndices, rois}); + Value newRois = Torch::AtenCatOp::create(rewriter, loc, catType, + tensorList, cstDim); // make constants from attributes - Value cstSpatialScale = rewriter.create( - loc, rewriter.getF64FloatAttr(spatialScaleFloat)); - Value pooledHeight = rewriter.create( - loc, rewriter.getI64IntegerAttr(outHInt)); - Value pooledWidth = rewriter.create( - loc, rewriter.getI64IntegerAttr(outWInt)); + Value cstSpatialScale = Torch::ConstantFloatOp::create( + rewriter, loc, rewriter.getF64FloatAttr(spatialScaleFloat)); + Value pooledHeight = Torch::ConstantIntOp::create( + rewriter, loc, rewriter.getI64IntegerAttr(outHInt)); + Value pooledWidth = Torch::ConstantIntOp::create( + rewriter, loc, rewriter.getI64IntegerAttr(outWInt)); // this is for consistency with the default pytorch sampling ratio value samplingRatioInt = (samplingRatioInt == 0) ? 
-1 : samplingRatioInt; - Value samplingRatio = rewriter.create( - loc, rewriter.getI64IntegerAttr(samplingRatioInt)); + Value samplingRatio = Torch::ConstantIntOp::create( + rewriter, loc, rewriter.getI64IntegerAttr(samplingRatioInt)); bool aligned = coordTfMode == "half_pixel"; - Value cstAligned = rewriter.create(loc, aligned); + Value cstAligned = + Torch::ConstantBoolOp::create(rewriter, loc, aligned); if (mode == "avg") { rewriter.replaceOpWithNewOp( @@ -3269,8 +3309,8 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( // mode == "max" auto indicesType = resultType.getWithSizesAndDtype( resultType.getOptionalSizes(), batchIndicesType.getDtype()); - auto roiPool = rewriter.create( - loc, TypeRange{resultType, indicesType}, input, newRois, + auto roiPool = Torch::TorchvisionRoiPoolOp::create( + rewriter, loc, TypeRange{resultType, indicesType}, input, newRois, cstSpatialScale, pooledHeight, pooledWidth); rewriter.replaceOp(binder.op, roiPool.getResult(0)); return success(); @@ -3298,34 +3338,35 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( "Expected input rank to be 4"); } - Value b = rewriter.create( - binder.getLoc(), input, - rewriter.create( - binder.getLoc(), rewriter.getI64IntegerAttr(0))); - Value c = rewriter.create( - binder.getLoc(), input, - rewriter.create( - binder.getLoc(), rewriter.getI64IntegerAttr(1))); - Value h = rewriter.create( - binder.getLoc(), input, - rewriter.create( - binder.getLoc(), rewriter.getI64IntegerAttr(2))); - Value w = rewriter.create( - binder.getLoc(), input, - rewriter.create( - binder.getLoc(), rewriter.getI64IntegerAttr(3))); - Value cstBlockSize = rewriter.create( - binder.getLoc(), rewriter.getI64IntegerAttr(blockSize)); - Value cstBlockSizeSquare = rewriter.create( - binder.getLoc(), rewriter.getI64IntegerAttr(blockSize * blockSize)); - Value hDivBlockSize = rewriter.create( - binder.getLoc(), h, cstBlockSize); - Value wDivBlockSize = rewriter.create( - binder.getLoc(), w, cstBlockSize); - hDivBlockSize 
= rewriter.create(binder.getLoc(), - hDivBlockSize); - wDivBlockSize = rewriter.create(binder.getLoc(), - wDivBlockSize); + Value b = Torch::AtenSizeIntOp::create( + rewriter, binder.getLoc(), input, + Torch::ConstantIntOp::create(rewriter, binder.getLoc(), + rewriter.getI64IntegerAttr(0))); + Value c = Torch::AtenSizeIntOp::create( + rewriter, binder.getLoc(), input, + Torch::ConstantIntOp::create(rewriter, binder.getLoc(), + rewriter.getI64IntegerAttr(1))); + Value h = Torch::AtenSizeIntOp::create( + rewriter, binder.getLoc(), input, + Torch::ConstantIntOp::create(rewriter, binder.getLoc(), + rewriter.getI64IntegerAttr(2))); + Value w = Torch::AtenSizeIntOp::create( + rewriter, binder.getLoc(), input, + Torch::ConstantIntOp::create(rewriter, binder.getLoc(), + rewriter.getI64IntegerAttr(3))); + Value cstBlockSize = Torch::ConstantIntOp::create( + rewriter, binder.getLoc(), rewriter.getI64IntegerAttr(blockSize)); + Value cstBlockSizeSquare = Torch::ConstantIntOp::create( + rewriter, binder.getLoc(), + rewriter.getI64IntegerAttr(blockSize * blockSize)); + Value hDivBlockSize = Torch::AtenDivIntOp::create( + rewriter, binder.getLoc(), h, cstBlockSize); + Value wDivBlockSize = Torch::AtenDivIntOp::create( + rewriter, binder.getLoc(), w, cstBlockSize); + hDivBlockSize = Torch::AtenIntFloatOp::create(rewriter, binder.getLoc(), + hDivBlockSize); + wDivBlockSize = Torch::AtenIntFloatOp::create(rewriter, binder.getLoc(), + wDivBlockSize); // The implementation is as follows: // tmp = np.reshape( @@ -3334,8 +3375,8 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( // tmp = np.transpose(tmp, [0, 3, 5, 1, 2, 4]) // y = np.reshape(tmp, [b, c * (blocksize**2), h // blocksize, w // // blocksize]) - Value reshapeSizesList = rewriter.create( - binder.getLoc(), + Value reshapeSizesList = Torch::PrimListConstructOp::create( + rewriter, binder.getLoc(), Torch::ListType::get(Torch::IntType::get(input.getContext())), llvm::SmallVector{b, c, hDivBlockSize, cstBlockSize, 
wDivBlockSize, cstBlockSize}); @@ -3348,8 +3389,8 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( SmallVector reshapeSizesInt{inputSizes[0], inputSizes[1], hDivBlockSizeInt, blockSize, wDivBlockSizeInt, blockSize}; - Value reshapedInput = rewriter.create( - binder.getLoc(), + Value reshapedInput = Torch::AtenReshapeOp::create( + rewriter, binder.getLoc(), inputTy.getWithSizesAndDtype(reshapeSizesInt, inputTy.getOptionalDtype()), input, reshapeSizesList); @@ -3362,10 +3403,10 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( return rewriter.notifyMatchFailure( binder.op, "Failed to create Torch Permute op"); - Value cMulBlockSizeSquare = rewriter.create( - binder.getLoc(), c, cstBlockSizeSquare); - reshapeSizesList = rewriter.create( - binder.getLoc(), + Value cMulBlockSizeSquare = Torch::AtenMulIntOp::create( + rewriter, binder.getLoc(), c, cstBlockSizeSquare); + reshapeSizesList = Torch::PrimListConstructOp::create( + rewriter, binder.getLoc(), Torch::ListType::get(Torch::IntType::get(input.getContext())), llvm::SmallVector{b, cMulBlockSizeSquare, hDivBlockSize, wDivBlockSize}); @@ -3416,29 +3457,29 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( // ZeroCast) output = Where (InputLessThanNegLambda, InputAddBias, // InputSubBiasOrZero) // } - Value constLambd = rewriter.create( - loc, rewriter.getFloatAttr(rewriter.getF64Type(), lambd)); - Value constBias = rewriter.create( - loc, rewriter.getFloatAttr(rewriter.getF64Type(), bias)); - Value constZero = rewriter.create( - loc, rewriter.getFloatAttr(rewriter.getF64Type(), 0.0)); - Value constOne = rewriter.create( - loc, rewriter.getFloatAttr(rewriter.getF64Type(), 1.0)); - Value constNegLambd = rewriter.create( - loc, rewriter.getFloatAttr(rewriter.getF64Type(), -lambd)); - - Value inputLTNegLambd = rewriter.create( - loc, comparisonResultType, input, constNegLambd); - Value inputPlusBias = rewriter.create( - loc, inputType, input, constBias, /*alpha=*/constOne); - Value inputSubBias = 
rewriter.create( - loc, inputType, input, constBias, /*alpha=*/constOne); - Value inputGTLambd = rewriter.create( - loc, comparisonResultType, input, constLambd); - - Value inputSubBiasOrZero = - rewriter.create( - loc, resultType, inputGTLambd, inputSubBias, constZero); + Value constLambd = Torch::ConstantFloatOp::create( + rewriter, loc, rewriter.getFloatAttr(rewriter.getF64Type(), lambd)); + Value constBias = Torch::ConstantFloatOp::create( + rewriter, loc, rewriter.getFloatAttr(rewriter.getF64Type(), bias)); + Value constZero = Torch::ConstantFloatOp::create( + rewriter, loc, rewriter.getFloatAttr(rewriter.getF64Type(), 0.0)); + Value constOne = Torch::ConstantFloatOp::create( + rewriter, loc, rewriter.getFloatAttr(rewriter.getF64Type(), 1.0)); + Value constNegLambd = Torch::ConstantFloatOp::create( + rewriter, loc, + rewriter.getFloatAttr(rewriter.getF64Type(), -lambd)); + + Value inputLTNegLambd = Torch::AtenLtScalarOp::create( + rewriter, loc, comparisonResultType, input, constNegLambd); + Value inputPlusBias = Torch::AtenAddScalarOp::create( + rewriter, loc, inputType, input, constBias, /*alpha=*/constOne); + Value inputSubBias = Torch::AtenSubScalarOp::create( + rewriter, loc, inputType, input, constBias, /*alpha=*/constOne); + Value inputGTLambd = Torch::AtenGtScalarOp::create( + rewriter, loc, comparisonResultType, input, constLambd); + + Value inputSubBiasOrZero = Torch::AtenWhereScalarOtherOp::create( + rewriter, loc, resultType, inputGTLambd, inputSubBias, constZero); rewriter.replaceOpWithNewOp( binder.op, resultType, inputLTNegLambd, inputPlusBias, inputSubBiasOrZero); @@ -3453,46 +3494,48 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( binder.tensorResultType(resultType)) return failure(); - Value index = rewriter.create( - binder.getLoc(), rewriter.getType(), - position); + Value index = Torch::AtenItemOp::create( + rewriter, binder.getLoc(), + rewriter.getType(), position); rewriter.replaceOpWithNewOp( binder.op, resultType, 
inputSequence, index); return success(); }); - patterns.onOp( - "SequenceEmpty", 11, - [](OpBinder binder, ConversionPatternRewriter &rewriter) { - Torch::ListType resultType; - int64_t dtypeIntOnnx; - if (binder.s64IntegerAttr(dtypeIntOnnx, "dtype", 1) || - binder.tensorListResultType(resultType)) - return failure(); - - std::optional dtypeIntTorch = - onnxDtypeIntToTorchDtypeInt(dtypeIntOnnx); - if (!dtypeIntTorch.has_value()) { - return rewriter.notifyMatchFailure( - binder.op, - "unimplemented support for the given dtype conversion"); - } - Value constDtype = rewriter.create( - binder.getLoc(), rewriter.getI64IntegerAttr(dtypeIntTorch.value())); - - Value shapeList = createConstantIntList(binder, rewriter, {}); - Value cstNone = rewriter.create(binder.getLoc()); - - Value self = rewriter.create( - binder.op->getLoc(), resultType.getContainedType(), shapeList, - /*dtype=*/constDtype, - /*layout=*/cstNone, - /*device=*/cstNone, /*pinMemory=*/cstNone, - /*memoryFormat=*/cstNone); + patterns.onOp("SequenceEmpty", 11, + [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ListType resultType; + int64_t dtypeIntOnnx; + if (binder.s64IntegerAttr(dtypeIntOnnx, "dtype", 1) || + binder.tensorListResultType(resultType)) + return failure(); - rewriter.replaceOpWithNewOp( - binder.op, resultType, llvm::SmallVector{self}); - return success(); - }); + std::optional dtypeIntTorch = + onnxDtypeIntToTorchDtypeInt(dtypeIntOnnx); + if (!dtypeIntTorch.has_value()) { + return rewriter.notifyMatchFailure( + binder.op, + "unimplemented support for the given dtype conversion"); + } + Value constDtype = Torch::ConstantIntOp::create( + rewriter, binder.getLoc(), + rewriter.getI64IntegerAttr(dtypeIntTorch.value())); + + Value shapeList = createConstantIntList(binder, rewriter, {}); + Value cstNone = + Torch::ConstantNoneOp::create(rewriter, binder.getLoc()); + + Value self = Torch::AtenEmptyMemoryFormatOp::create( + rewriter, binder.op->getLoc(), + 
resultType.getContainedType(), shapeList, + /*dtype=*/constDtype, + /*layout=*/cstNone, + /*device=*/cstNone, /*pinMemory=*/cstNone, + /*memoryFormat=*/cstNone); + + rewriter.replaceOpWithNewOp( + binder.op, resultType, llvm::SmallVector{self}); + return success(); + }); patterns.onOp( "SequenceErase", 11, [](OpBinder binder, ConversionPatternRewriter &rewriter) { @@ -3502,17 +3545,19 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( binder.tensorListResultType(resultType)) return failure(); - Value length = rewriter.create( - binder.getLoc(), rewriter.getType(), inputSequence); + Value length = Torch::AtenLenTOp::create( + rewriter, binder.getLoc(), rewriter.getType(), + inputSequence); - Value cstNone = rewriter.create(binder.getLoc()); - Value cstOne = rewriter.create( - binder.getLoc(), rewriter.getI64IntegerAttr(1)); + Value cstNone = + Torch::ConstantNoneOp::create(rewriter, binder.getLoc()); + Value cstOne = Torch::ConstantIntOp::create( + rewriter, binder.getLoc(), rewriter.getI64IntegerAttr(1)); if (binder.op->getNumOperands() == 1) { // If True, it means that the `position` arg is missing and // the last tensor from the list has to be erased. - Value lengthMinusOne = rewriter.create( - binder.getLoc(), length, cstOne); + Value lengthMinusOne = Torch::AtenSubIntOp::create( + rewriter, binder.getLoc(), length, cstOne); rewriter.replaceOpWithNewOp( binder.op, resultType, inputSequence, /*start=*/cstNone, /*end=*/lengthMinusOne, /*step=*/cstOne); @@ -3522,27 +3567,29 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( if (binder.tensorOperandAtIndex(position, 1)) return failure(); - Value positionInt = rewriter.create( - binder.getLoc(), rewriter.getType(), position); + Value positionInt = Torch::AtenItemOp::create( + rewriter, binder.getLoc(), rewriter.getType(), + position); // Handling negative position value. 
- Value cstZero = rewriter.create( - binder.getLoc(), rewriter.getI64IntegerAttr(0)); - Value isPositionNegative = rewriter.create( - binder.getLoc(), positionInt, cstZero); - isPositionNegative = rewriter.create( - binder.getLoc(), isPositionNegative); - Value finalOffset = rewriter.create( - binder.getLoc(), isPositionNegative, length); - positionInt = rewriter.create( - binder.getLoc(), positionInt, finalOffset); - - Value listBeforePosition = rewriter.create( - binder.getLoc(), resultType, inputSequence, /*start=*/cstNone, - /*end=*/positionInt, /*step=*/cstOne); - Value positionPlusOne = rewriter.create( - binder.getLoc(), positionInt, cstOne); - Value listAfterPosition = rewriter.create( - binder.getLoc(), resultType, inputSequence, + Value cstZero = Torch::ConstantIntOp::create( + rewriter, binder.getLoc(), rewriter.getI64IntegerAttr(0)); + Value isPositionNegative = Torch::AtenLtIntOp::create( + rewriter, binder.getLoc(), positionInt, cstZero); + isPositionNegative = Torch::AtenIntBoolOp::create( + rewriter, binder.getLoc(), isPositionNegative); + Value finalOffset = Torch::AtenMulIntOp::create( + rewriter, binder.getLoc(), isPositionNegative, length); + positionInt = Torch::AtenAddIntOp::create(rewriter, binder.getLoc(), + positionInt, finalOffset); + + Value listBeforePosition = + Torch::AtenSliceTOp::create(rewriter, binder.getLoc(), resultType, + inputSequence, /*start=*/cstNone, + /*end=*/positionInt, /*step=*/cstOne); + Value positionPlusOne = Torch::AtenAddIntOp::create( + rewriter, binder.getLoc(), positionInt, cstOne); + Value listAfterPosition = Torch::AtenSliceTOp::create( + rewriter, binder.getLoc(), resultType, inputSequence, /*start=*/positionPlusOne, /*end=*/length, /*step=*/cstOne); @@ -3550,39 +3597,40 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( binder.op, resultType, listBeforePosition, listAfterPosition); return success(); }); - patterns.onOp( - "SequenceInsert", 11, - [](OpBinder binder, ConversionPatternRewriter &rewriter) { 
- Torch::ListType resultType; - Value inputSequence, position, insertValue; - if (binder.tensorListOperandAtIndex(inputSequence, 0) || - binder.tensorOperandAtIndex(insertValue, 1) || - binder.tensorListResultType(resultType)) - return failure(); + patterns.onOp("SequenceInsert", 11, + [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ListType resultType; + Value inputSequence, position, insertValue; + if (binder.tensorListOperandAtIndex(inputSequence, 0) || + binder.tensorOperandAtIndex(insertValue, 1) || + binder.tensorListResultType(resultType)) + return failure(); - if (binder.op->getNumOperands() == 1) { - // If True, it means that the `position` arg is missing and - // the tensor has to be inserted at the end of the list. - Value length = rewriter.create( - binder.getLoc(), rewriter.getType(), - inputSequence); - rewriter.replaceOpWithNewOp( - binder.op, inputSequence, /*idx=*/length, - /*el=*/insertValue); - return success(); - } + if (binder.op->getNumOperands() == 1) { + // If True, it means that the `position` arg is missing and + // the tensor has to be inserted at the end of the list. 
+ Value length = Torch::AtenLenTOp::create( + rewriter, binder.getLoc(), + rewriter.getType(), inputSequence); + rewriter.replaceOpWithNewOp( + binder.op, inputSequence, /*idx=*/length, + /*el=*/insertValue); + return success(); + } - if (binder.tensorOperandAtIndex(position, 2)) - return failure(); + if (binder.tensorOperandAtIndex(position, 2)) + return failure(); - Value positionInt = rewriter.create( - binder.getLoc(), rewriter.getType(), position); - rewriter.create(binder.getLoc(), inputSequence, - /*idx=*/positionInt, - /*el=*/insertValue); - rewriter.replaceOp(binder.op, inputSequence); - return success(); - }); + Value positionInt = Torch::AtenItemOp::create( + rewriter, binder.getLoc(), + rewriter.getType(), position); + Torch::AtenInsertTOp::create(rewriter, binder.getLoc(), + inputSequence, + /*idx=*/positionInt, + /*el=*/insertValue); + rewriter.replaceOp(binder.op, inputSequence); + return success(); + }); patterns.onOp( "SequenceMap", 17, [](OpBinder binder, ConversionPatternRewriter &rewriter) { @@ -3604,23 +3652,27 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( dyn_cast(resultType.getContainedType()); Value shapeList = createConstantIntList(binder, rewriter, resultTensorType.getSizes()); - Value cstNone = rewriter.create(binder.getLoc()); - Value self = rewriter.create( - binder.op->getLoc(), resultType.getContainedType(), shapeList, + Value cstNone = + Torch::ConstantNoneOp::create(rewriter, binder.getLoc()); + Value self = Torch::AtenEmptyMemoryFormatOp::create( + rewriter, binder.op->getLoc(), resultType.getContainedType(), + shapeList, /*dtype=*/cstNone, /*layout=*/cstNone, /*device=*/cstNone, /*pinMemory=*/cstNone, /*memoryFormat=*/cstNone); - Value result = rewriter.create( - binder.getLoc(), resultType, llvm::SmallVector{self}); + Value result = Torch::PrimListConstructOp::create( + rewriter, binder.getLoc(), resultType, + llvm::SmallVector{self}); // create a for-like primLoopOp // with the length of sequence as max iter_num - 
Value len = rewriter.create( - binder.getLoc(), rewriter.getType(), operands[0]); - auto cstTrue = rewriter.create( - binder.getLoc(), rewriter.getBoolAttr(true)); + Value len = Torch::AtenLenTOp::create( + rewriter, binder.getLoc(), rewriter.getType(), + operands[0]); + auto cstTrue = Torch::ConstantBoolOp::create( + rewriter, binder.getLoc(), rewriter.getBoolAttr(true)); mlir::ImplicitLocOpBuilder b(binder.getLoc(), rewriter); auto loop = - b.create(resultType, len, cstTrue, result); + Torch::PrimLoopOp::create(b, resultType, len, cstTrue, result); rewriter.cloneRegionBefore(*bodyRegion, loop.getRegion(), loop.getRegion().begin()); @@ -3641,8 +3693,8 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( auto tensorType = dyn_cast( dyn_cast(operands[i].getType()) .getContainedType()); - Value item = rewriter.create( - binder.getLoc(), tensorType, operands[i], indexArg); + Value item = Torch::Aten__Getitem__TOp::create( + rewriter, binder.getLoc(), tensorType, operands[i], indexArg); argInput.replaceAllUsesWith(item); } else { argInput.replaceAllUsesWith(operands[i]); @@ -3656,8 +3708,9 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( rewriter.setInsertionPoint(terminator); // update sequence input auto terminatorOperands = terminator->getOperands(); - Value append = rewriter.create( - binder.getLoc(), resultType, sequenceArg, terminatorOperands[0]); + Value append = + Torch::AtenAppendTOp::create(rewriter, binder.getLoc(), resultType, + sequenceArg, terminatorOperands[0]); rewriter.replaceOpWithNewOp( terminator, cstTrue, append); @@ -3695,10 +3748,11 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( mode = "trilinear"; } Value modeStrValue = - rewriter.create(binder.getLoc(), mode); - Value cstNone = rewriter.create(binder.getLoc()); - Value cstFalse = rewriter.create( - binder.getLoc(), rewriter.getBoolAttr(false)); + Torch::ConstantStrOp::create(rewriter, binder.getLoc(), mode); + Value cstNone = + Torch::ConstantNoneOp::create(rewriter, 
binder.getLoc()); + Value cstFalse = Torch::ConstantBoolOp::create( + rewriter, binder.getLoc(), rewriter.getBoolAttr(false)); rewriter .replaceOpWithNewOp( @@ -3780,11 +3834,13 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( } if (frameLengthIsNone) { if (windowIsNone) { - frameLength = rewriter.create( - binder.getLoc(), rewriter.getI64IntegerAttr(signalShape[1])); + frameLength = Torch::ConstantIntOp::create( + rewriter, binder.getLoc(), + rewriter.getI64IntegerAttr(signalShape[1])); } else { - frameLength = rewriter.create( - binder.getLoc(), rewriter.getI64IntegerAttr(windowShape[0])); + frameLength = Torch::ConstantIntOp::create( + rewriter, binder.getLoc(), + rewriter.getI64IntegerAttr(windowShape[0])); } } @@ -3802,14 +3858,15 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( auto onesResultTy = rewriter.getType( ArrayRef({-1}), signalTy.getDtype()); - Value none = rewriter.create(binder.getLoc()); - Value sizes = rewriter.create( - binder.getLoc(), + Value none = Torch::ConstantNoneOp::create(rewriter, binder.getLoc()); + Value sizes = Torch::PrimListConstructOp::create( + rewriter, binder.getLoc(), Torch::ListType::get( Torch::IntType::get(binder.op->getContext())), SmallVector{frameLengthItem}); - window = rewriter.create( - binder.getLoc(), onesResultTy, sizes, none, none, none, none); + window = + Torch::AtenOnesOp::create(rewriter, binder.getLoc(), onesResultTy, + sizes, none, none, none, none); } FailureOr complexDtype; @@ -3847,16 +3904,16 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( // into torch.stft. 
if (signalShape.size() == 3) { if (signalShape[2] == 2) { - signal = rewriter.create( - binder.getLoc(), complexSignalTy, signal); + signal = Torch::AtenViewAsComplexOp::create( + rewriter, binder.getLoc(), complexSignalTy, signal); } else { - Value two = rewriter.create( - binder.getLoc(), rewriter.getI64IntegerAttr(2)); + Value two = Torch::ConstantIntOp::create( + rewriter, binder.getLoc(), rewriter.getI64IntegerAttr(2)); auto newSignalTy = signalTy.getWithSizesAndDtype( ArrayRef({signalShape[0], signalShape[1]}), signalTy.getDtype()); - signal = rewriter.create( - binder.getLoc(), newSignalTy, signal, two); + signal = Torch::AtenSqueezeDimOp::create(rewriter, binder.getLoc(), + newSignalTy, signal, two); } } @@ -3864,18 +3921,19 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( // as the length of the window. Value windowLen; if (!windowIsNone) { - windowLen = rewriter.create( - binder.getLoc(), rewriter.getI64IntegerAttr(windowShape[0])); + windowLen = Torch::ConstantIntOp::create( + rewriter, binder.getLoc(), + rewriter.getI64IntegerAttr(windowShape[0])); } else { windowLen = frameLengthItem; } Value falseVal = - rewriter.create(binder.getLoc(), false); + Torch::ConstantBoolOp::create(rewriter, binder.getLoc(), false); Value trueVal = - rewriter.create(binder.getLoc(), true); + Torch::ConstantBoolOp::create(rewriter, binder.getLoc(), true); Value padMode = - rewriter.create(binder.getLoc(), "reflect"); + Torch::ConstantStrOp::create(rewriter, binder.getLoc(), "reflect"); auto stftTy = complexSignalTy.getWithSizesAndDtype( ArrayRef({resultShape[0], resultShape[2], resultShape[1]}), complexSignalTy.getDtype()); @@ -3888,17 +3946,17 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( // shape of stft to match the shape of resultType. Also, it is // immaterial whether torch.view_as_real is called after or before the // permutation; both outputs will be equivalent. 
- Value stft = rewriter.create( - binder.getLoc(), stftTy, signal, frameLengthItem, frameStepItem, - windowLen, window, falseVal, padMode, falseVal, + Value stft = Torch::AtenStftCenterOp::create( + rewriter, binder.getLoc(), stftTy, signal, frameLengthItem, + frameStepItem, windowLen, window, falseVal, padMode, falseVal, onesided ? trueVal : falseVal, trueVal, falseVal); auto permuteStftTy = complexSignalTy.getWithSizesAndDtype( ArrayRef({resultShape[0], resultShape[1], resultShape[2]}), complexSignalTy.getDtype()); Value permuteDims = createConstantIntList(binder, rewriter, {0, 2, 1}); - Value permutedStft = rewriter.create( - binder.getLoc(), permuteStftTy, stft, permuteDims); + Value permutedStft = Torch::AtenPermuteOp::create( + rewriter, binder.getLoc(), permuteStftTy, stft, permuteDims); rewriter.replaceOpWithNewOp( binder.op, resultType, permutedStft); @@ -3921,14 +3979,14 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( SmallVector inputShape(inputTy.getSizes()); auto dtype = resultType.getDtype(); - Value cstZero = rewriter.create( - binder.getLoc(), rewriter.getI64IntegerAttr(0)); - Value cstOne = rewriter.create( - binder.getLoc(), rewriter.getI64IntegerAttr(1)); - Value batchAxisVal = rewriter.create( - binder.getLoc(), rewriter.getI64IntegerAttr(batchAxis)); - Value timeAxisVal = rewriter.create( - binder.getLoc(), rewriter.getI64IntegerAttr(timeAxis)); + Value cstZero = Torch::ConstantIntOp::create( + rewriter, binder.getLoc(), rewriter.getI64IntegerAttr(0)); + Value cstOne = Torch::ConstantIntOp::create( + rewriter, binder.getLoc(), rewriter.getI64IntegerAttr(1)); + Value batchAxisVal = Torch::ConstantIntOp::create( + rewriter, binder.getLoc(), rewriter.getI64IntegerAttr(batchAxis)); + Value timeAxisVal = Torch::ConstantIntOp::create( + rewriter, binder.getLoc(), rewriter.getI64IntegerAttr(timeAxis)); SmallVector sliceShape(inputShape); sliceShape[batchAxis] = 1; @@ -3943,39 +4001,43 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( for 
(int i = 0; i < inputShape[batchAxis]; i++) { // slice i iterating on batch axis - Value k = rewriter.create( - binder.getLoc(), rewriter.getI64IntegerAttr(i)); + Value k = Torch::ConstantIntOp::create(rewriter, binder.getLoc(), + rewriter.getI64IntegerAttr(i)); Value end = - rewriter.create(binder.getLoc(), k, cstOne); - Value sliceBatch = rewriter.create( - binder.getLoc(), sliceType, input, batchAxisVal, k, end, cstOne); + Torch::AtenAddIntOp::create(rewriter, binder.getLoc(), k, cstOne); + Value sliceBatch = Torch::AtenSliceTensorOp::create( + rewriter, binder.getLoc(), sliceType, input, batchAxisVal, k, end, + cstOne); // get sequence length and slice the reversing part - Value kTensor = rewriter.create( - binder.getLoc(), scalarTensorType, k); - Value sel = rewriter.create( - binder.getLoc(), scalarTensorType, sequenceLens, cstZero, - kTensor); - Value len = rewriter.create( - binder.getLoc(), rewriter.getType(), sel); - Value sliceTime = rewriter.create( - binder.getLoc(), flipType, sliceBatch, timeAxisVal, cstZero, len, - cstOne); + Value kTensor = Torch::PrimNumToTensorScalarOp::create( + rewriter, binder.getLoc(), scalarTensorType, k); + Value sel = Torch::AtenIndexSelectOp::create( + rewriter, binder.getLoc(), scalarTensorType, sequenceLens, + cstZero, kTensor); + Value len = Torch::AtenItemOp::create( + rewriter, binder.getLoc(), rewriter.getType(), + sel); + Value sliceTime = Torch::AtenSliceTensorOp::create( + rewriter, binder.getLoc(), flipType, sliceBatch, timeAxisVal, + cstZero, len, cstOne); // flip the sliced reversing tensor - Value dims = rewriter.create( - binder.getLoc(), + Value dims = Torch::PrimListConstructOp::create( + rewriter, binder.getLoc(), rewriter.getType( rewriter.getType()), SmallVector{timeAxisVal}); - Value flip = rewriter.create( - binder.getLoc(), flipType, sliceTime, dims); + Value flip = Torch::AtenFlipOp::create(rewriter, binder.getLoc(), + flipType, sliceTime, dims); // embeds the reversed tensor to the input - Value 
embedTime = rewriter.create( - binder.getLoc(), sliceType, sliceBatch, flip, timeAxisVal, + Value embedTime = Torch::AtenSliceScatterOp::create( + rewriter, binder.getLoc(), sliceType, sliceBatch, flip, + timeAxisVal, /*start=*/cstZero, /*end=*/len, /*step=*/cstOne); - input = rewriter.create( - binder.getLoc(), resultType, input, embedTime, batchAxisVal, + input = Torch::AtenSliceScatterOp::create( + rewriter, binder.getLoc(), resultType, input, embedTime, + batchAxisVal, /*start=*/k, /*end=*/end, /*step=*/cstOne); } @@ -4050,25 +4112,25 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( // step 2. Get dimension list of data. SmallVector dataDims; for (int64_t i = 0; i < dataRank; ++i) { - Value k = rewriter.create(loc, i); - Value dataDim = rewriter.create(loc, data, k); + Value k = Torch::ConstantIntOp::create(rewriter, loc, i); + Value dataDim = Torch::AtenSizeIntOp::create(rewriter, loc, data, k); dataDims.push_back(dataDim); } // step 3. Get dimension list of indices. - Value constZero = rewriter.create( - loc, rewriter.getI64IntegerAttr(0)); - Value constOne = rewriter.create( - loc, rewriter.getI64IntegerAttr(1)); + Value constZero = Torch::ConstantIntOp::create( + rewriter, loc, rewriter.getI64IntegerAttr(0)); + Value constOne = Torch::ConstantIntOp::create( + rewriter, loc, rewriter.getI64IntegerAttr(1)); SmallVector indicesDimsMinusOne; Value indicesFlattenDim = constOne; for (int64_t i = 0; i < indicesRank - 1; ++i) { - Value k = rewriter.create(loc, i); + Value k = Torch::ConstantIntOp::create(rewriter, loc, i); Value indicesDim = - rewriter.create(loc, indices, k); + Torch::AtenSizeIntOp::create(rewriter, loc, indices, k); indicesDimsMinusOne.push_back(indicesDim); - indicesFlattenDim = rewriter.create( - loc, indicesFlattenDim, indicesDim); + indicesFlattenDim = Torch::AtenMulIntOp::create( + rewriter, loc, indicesFlattenDim, indicesDim); } ArrayRef indicesShapeMinusOne = indicesShape.drop_back(); @@ -4094,8 +4156,8 @@ void 
mlir::torch::onnx_c::populateDefaultDomainQtoZ( // step 4. Convert indices_shape[-1] dimensional indexing to 1D // indexing. - Value sliceDim = rewriter.create( - loc, rewriter.getI64IntegerAttr(indicesRank - 1)); + Value sliceDim = Torch::ConstantIntOp::create( + rewriter, loc, rewriter.getI64IntegerAttr(indicesRank - 1)); SmallVector indicesSliceShape(indicesShapeMinusOne); indicesSliceShape.push_back(1); auto indicesSliceTy = rewriter.getType( @@ -4104,28 +4166,29 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( Value start = constZero; Value updatedIndices; for (int64_t i = 0; i < indicesLastDim; ++i) { - Value end = rewriter.create( - loc, rewriter.getI64IntegerAttr(i + 1)); - Value indicesSlice = rewriter.create( - loc, indicesSliceTy, indices, sliceDim, start, end, + Value end = Torch::ConstantIntOp::create( + rewriter, loc, rewriter.getI64IntegerAttr(i + 1)); + Value indicesSlice = Torch::AtenSliceTensorOp::create( + rewriter, loc, indicesSliceTy, indices, sliceDim, start, end, /*step=*/constOne); start = end; // Apply bounds checking on the indices slice. 
auto boolTy = rewriter.getType( indicesSliceShape, rewriter.getI1Type()); - Value lt = rewriter.create( - loc, boolTy, indicesSlice, constZero); - Value add = rewriter.create( - loc, indicesSliceTy, indicesSlice, dataDims[i], + Value lt = Torch::AtenLtScalarOp::create(rewriter, loc, boolTy, + indicesSlice, constZero); + Value add = Torch::AtenAddScalarOp::create( + rewriter, loc, indicesSliceTy, indicesSlice, dataDims[i], /*alpha=*/constOne); - indicesSlice = rewriter.create( - loc, indicesSliceTy, lt, add, indicesSlice); + indicesSlice = Torch::AtenWhereSelfOp::create( + rewriter, loc, indicesSliceTy, lt, add, indicesSlice); if (i == 0) { updatedIndices = indicesSlice; continue; } - updatedIndices = rewriter.create( - loc, indicesSliceTy, indicesSlice, updatedIndices, dataDims[i]); + updatedIndices = Torch::AtenAddTensorOp::create( + rewriter, loc, indicesSliceTy, indicesSlice, updatedIndices, + dataDims[i]); } // step 5. Compute all the required result types here. @@ -4188,13 +4251,13 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( // data by inserting unit dimensions. auto intListTy = rewriter.getType( rewriter.getType()); - Value reshapeIndicesSizeList = - rewriter.create(loc, intListTy, - reshapeIndicesDims); + Value reshapeIndicesSizeList = Torch::PrimListConstructOp::create( + rewriter, loc, intListTy, reshapeIndicesDims); auto reshapeIndicesTy = rewriter.getType( reshapeIndicesShape, indicesTy.getOptionalDtype()); - Value reshapedIndices = rewriter.create( - loc, reshapeIndicesTy, updatedIndices, reshapeIndicesSizeList); + Value reshapedIndices = + Torch::AtenViewOp::create(rewriter, loc, reshapeIndicesTy, + updatedIndices, reshapeIndicesSizeList); // step 7. Flatten `q-1` dimensions of the indices and updates. 
auto flattenIndicesTy = rewriter.getType( @@ -4204,39 +4267,40 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( Value flattenedIndices = reshapedIndices; Value flattenedUpdates = updates; if (indicesRank == 1) { - flattenedIndices = rewriter.create( - loc, flattenIndicesTy, reshapedIndices, constZero); - flattenedUpdates = rewriter.create( - loc, flattenUpdatesTy, updates, constZero); + flattenedIndices = Torch::AtenUnsqueezeOp::create( + rewriter, loc, flattenIndicesTy, reshapedIndices, constZero); + flattenedUpdates = Torch::AtenUnsqueezeOp::create( + rewriter, loc, flattenUpdatesTy, updates, constZero); } else if (indicesRank > 1) { - Value endDim = rewriter.create( - loc, rewriter.getI64IntegerAttr(indicesRank - 2)); - flattenedIndices = rewriter.create( - loc, flattenIndicesTy, reshapedIndices, constZero, endDim); - flattenedUpdates = rewriter.create( - loc, flattenUpdatesTy, updates, constZero, endDim); + Value endDim = Torch::ConstantIntOp::create( + rewriter, loc, rewriter.getI64IntegerAttr(indicesRank - 2)); + flattenedIndices = Torch::AtenFlattenUsingIntsOp::create( + rewriter, loc, flattenIndicesTy, reshapedIndices, constZero, + endDim); + flattenedUpdates = Torch::AtenFlattenUsingIntsOp::create( + rewriter, loc, flattenUpdatesTy, updates, constZero, endDim); } // step 8. Expand `r-indices_shape[-1]` dims of flattened indices. 
auto expandIndicesTy = rewriter.getType( expandIndicesShape, indicesTy.getOptionalDtype()); - Value expandIndicesSizeList = - rewriter.create(loc, intListTy, - expandIndicesDims); - Value constFalse = rewriter.create( - loc, rewriter.getType(), + Value expandIndicesSizeList = Torch::PrimListConstructOp::create( + rewriter, loc, intListTy, expandIndicesDims); + Value constFalse = Torch::ConstantBoolOp::create( + rewriter, loc, rewriter.getType(), rewriter.getBoolAttr(false)); - Value expandedIndices = rewriter.create( - loc, expandIndicesTy, flattenedIndices, expandIndicesSizeList, - /*implicit=*/constFalse); + Value expandedIndices = + Torch::AtenExpandOp::create(rewriter, loc, expandIndicesTy, + flattenedIndices, expandIndicesSizeList, + /*implicit=*/constFalse); // step 9. Flatten indices_shape[-1] dimensions of data. auto flattenDataTy = rewriter.getType( flattenDataShape, dataTy.getOptionalDtype()); - Value endDim = rewriter.create( - loc, rewriter.getI64IntegerAttr(indicesLastDim - 1)); - Value flattenedData = rewriter.create( - loc, flattenDataTy, data, constZero, endDim); + Value endDim = Torch::ConstantIntOp::create( + rewriter, loc, rewriter.getI64IntegerAttr(indicesLastDim - 1)); + Value flattenedData = Torch::AtenFlattenUsingIntsOp::create( + rewriter, loc, flattenDataTy, data, constZero, endDim); // step 10. Now we have flattenedData, expandedIndices and // flattenedUpdates of same rank to perform scatter operation. 
@@ -4245,17 +4309,17 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( Value scatter; if (reduction == "none") { - scatter = rewriter.create( - loc, scatterTy, flattenedData, /*axis=*/constZero, + scatter = Torch::AtenScatterSrcOp::create( + rewriter, loc, scatterTy, flattenedData, /*axis=*/constZero, expandedIndices, flattenedUpdates); } else { Value cstReduction = - rewriter.create(loc, reduction); - Value constTrue = rewriter.create( - loc, rewriter.getType(), + Torch::ConstantStrOp::create(rewriter, loc, reduction); + Value constTrue = Torch::ConstantBoolOp::create( + rewriter, loc, rewriter.getType(), rewriter.getBoolAttr(true)); - scatter = rewriter.create( - loc, scatterTy, flattenedData, /*axis=*/constZero, + scatter = Torch::AtenScatterReduceTwoOp::create( + rewriter, loc, scatterTy, flattenedData, /*axis=*/constZero, expandedIndices, flattenedUpdates, cstReduction, /*include_self=*/constTrue); } @@ -4265,8 +4329,8 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( rewriter.replaceOp(binder.op, scatter); return success(); } - Value unflattenSizeList = rewriter.create( - loc, intListTy, dataDims); + Value unflattenSizeList = Torch::PrimListConstructOp::create( + rewriter, loc, intListTy, dataDims); rewriter.replaceOpWithNewOp( binder.op, resultType, scatter, constZero, unflattenSizeList); return success(); @@ -4305,8 +4369,8 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( binder.op, "Not converting to AtenSplitToSequenceOp due to inputs "); - Value axisValue = rewriter.create( - binder.getLoc(), rewriter.getType(), + Value axisValue = Torch::ConstantIntOp::create( + rewriter, binder.getLoc(), rewriter.getType(), rewriter.getI64IntegerAttr(axis)); auto splitTy = cast(split.getType()); @@ -4326,8 +4390,9 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( binder.op, "Dynamic shapes for Split is not yet supported"); } else if (splitSizes[0] <= 1) { // dealing with 1/0 element in 1-D tensor - Value splitInt = rewriter.create( - 
binder.getLoc(), rewriter.getType(), split); + Value splitInt = Torch::AtenItemOp::create( + rewriter, binder.getLoc(), rewriter.getType(), + split); rewriter.replaceOpWithNewOp( binder.op, resultType, self, splitInt, axisValue); return success(); @@ -4340,8 +4405,9 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( return success(); } } else if (splitDim == 0) { // Handle 0-D tensor - Value splitInt = rewriter.create( - binder.getLoc(), rewriter.getType(), split); + Value splitInt = Torch::AtenItemOp::create( + rewriter, binder.getLoc(), rewriter.getType(), + split); rewriter.replaceOpWithNewOp( binder.op, resultType, self, splitInt, axisValue); return success(); @@ -4361,7 +4427,7 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( binder.tensorResultTypes(resultTypes)) return failure(); - Value zero = rewriter.create(binder.getLoc(), 0); + Value zero = Torch::ConstantIntOp::create(rewriter, binder.getLoc(), 0); auto inputTy = cast(input.getType()); if (!inputTy.hasSizes()) { @@ -4378,18 +4444,18 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( if (axis < -1 * inputDim || axis > inputDim - 1) return rewriter.notifyMatchFailure(binder.op, "invalid value for axis"); - axisVal = rewriter.create( - binder.getLoc(), rewriter.getI64IntegerAttr(axis)); + axisVal = Torch::ConstantIntOp::create( + rewriter, binder.getLoc(), rewriter.getI64IntegerAttr(axis)); axisWasNone = false; } else { axisVal = zero; axisWasNone = true; } - Value sortedVal = rewriter.create( - binder.getLoc(), rewriter.getBoolAttr(sorted)); + Value sortedVal = Torch::ConstantBoolOp::create( + rewriter, binder.getLoc(), rewriter.getBoolAttr(sorted)); Value trueVal = - rewriter.create(binder.getLoc(), true); + Torch::ConstantBoolOp::create(rewriter, binder.getLoc(), true); // The shape of inverse_indices is the same as input shape, but // resulTypes[2] must be used to avoid live value after conversion. 
@@ -4413,15 +4479,15 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( auto flattenResultTy = rewriter.getType( ArrayRef({inputNumel}), inputTy.getDtype()); Value negativeOne = - rewriter.create(binder.getLoc(), -1); - input = rewriter.create( - binder.getLoc(), flattenResultTy, input, zero, negativeOne); + Torch::ConstantIntOp::create(rewriter, binder.getLoc(), -1); + input = Torch::AtenFlattenUsingIntsOp::create( + rewriter, binder.getLoc(), flattenResultTy, input, zero, + negativeOne); } - Torch::AtenUniqueDimOp intermResults = - rewriter.create( - binder.getLoc(), outputTy, inverseTy, countsTy, input, axisVal, - sortedVal, trueVal, trueVal); + Torch::AtenUniqueDimOp intermResults = Torch::AtenUniqueDimOp::create( + rewriter, binder.getLoc(), outputTy, inverseTy, countsTy, input, + axisVal, sortedVal, trueVal, trueVal); SmallVector uniqueResults = intermResults.getResults(); @@ -4431,38 +4497,41 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( auto arangeResultType = rewriter.getType( ArrayRef({inputShape[0]}), countsTy.getOptionalDtype()); - Value inputDimZero = rewriter.create( - binder.getLoc(), rewriter.getI64IntegerAttr(inputShape[0])); - Value int64Type = rewriter.create( - binder.getLoc(), rewriter.getI64IntegerAttr(4)); - Value noneVal = rewriter.create(binder.getLoc()); + Value inputDimZero = Torch::ConstantIntOp::create( + rewriter, binder.getLoc(), + rewriter.getI64IntegerAttr(inputShape[0])); + Value int64Type = Torch::ConstantIntOp::create( + rewriter, binder.getLoc(), rewriter.getI64IntegerAttr(4)); + Value noneVal = + Torch::ConstantNoneOp::create(rewriter, binder.getLoc()); - Value perm = rewriter.create( - binder.getLoc(), arangeResultType, inputDimZero, + Value perm = Torch::AtenArangeOp::create( + rewriter, binder.getLoc(), arangeResultType, inputDimZero, /*dtype=*/int64Type, /*layout=*/noneVal, /*device=*/noneVal, /*pin_memory=*/noneVal); // Inverse has the same shape as input, but the dtype is not the same. 
Value flipDims = createConstantIntList(binder, rewriter, {0}); - Value inverse = rewriter.create( - binder.getLoc(), + Value inverse = Torch::AtenFlipOp::create( + rewriter, binder.getLoc(), inputTy.getWithSizesAndDtype(inputShape, countsTy.getDtype()), uniqueResults[1], flipDims); - perm = rewriter.create( - binder.getLoc(), cast(perm.getType()), perm, - flipDims); + perm = Torch::AtenFlipOp::create( + rewriter, binder.getLoc(), + cast(perm.getType()), perm, flipDims); auto newInverseTy = rewriter.getType( ArrayRef({outputTy.getSizes()[0]}), countsTy.getDtype()); Value newInverseSize = createConstantIntList(binder, rewriter, {outputTy.getSizes()[0]}); - Value newInverse = rewriter.create( - binder.getLoc(), newInverseTy, inverse, newInverseSize, + Value newInverse = Torch::AtenNewEmptyOp::create( + rewriter, binder.getLoc(), newInverseTy, inverse, newInverseSize, /*dtype=*/int64Type, /*layout=*/noneVal, /*device=*/noneVal, /*pin_memory=*/noneVal); - Value firstOccurIndices = rewriter.create( - binder.getLoc(), resultTypes[1], newInverse, zero, inverse, perm); + Value firstOccurIndices = Torch::AtenScatterSrcOp::create( + rewriter, binder.getLoc(), resultTypes[1], newInverse, zero, + inverse, perm); rewriter.replaceOp(binder.op, {uniqueResults[0], firstOccurIndices, uniqueResults[1], uniqueResults[2]}); @@ -4509,41 +4578,42 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( binder.op, "input batch dimension cannot be dynamic"); int batch_size = (is_2d) ? 
inputShape[0] : 1; - Value none = rewriter.create(binder.getLoc()); - Value zero = rewriter.create( - binder.getLoc(), rewriter.getType(), + Value none = Torch::ConstantNoneOp::create(rewriter, binder.getLoc()); + Value zero = Torch::ConstantIntOp::create( + rewriter, binder.getLoc(), rewriter.getType(), rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0)); - Value one = rewriter.create( - binder.getLoc(), rewriter.getI64IntegerAttr(1)); - Value cstFalse = rewriter.create( - binder.getLoc(), rewriter.getBoolAttr(false)); + Value one = Torch::ConstantIntOp::create(rewriter, binder.getLoc(), + rewriter.getI64IntegerAttr(1)); + Value cstFalse = Torch::ConstantBoolOp::create( + rewriter, binder.getLoc(), rewriter.getBoolAttr(false)); auto intType = rewriter.getType(); - Value loopConditionTrue = rewriter.create( - binder.getLoc(), rewriter.getBoolAttr(true)); + Value loopConditionTrue = Torch::ConstantBoolOp::create( + rewriter, binder.getLoc(), rewriter.getBoolAttr(true)); Type loopIndexType = intType; // create a zero tensor for output SmallVector resultShape(resultType.getSizes()); int64_t rank = resultShape.size(); SmallVector zerosShapeValues; for (int j = 0; j < rank; j++) { - Value dimSize = rewriter.create( - binder.getLoc(), rewriter.getI64IntegerAttr(resultShape[j])); + Value dimSize = Torch::ConstantIntOp::create( + rewriter, binder.getLoc(), + rewriter.getI64IntegerAttr(resultShape[j])); zerosShapeValues.push_back(dimSize); } - Value zerosShapeList = rewriter.create( - binder.getLoc(), + Value zerosShapeList = Torch::PrimListConstructOp::create( + rewriter, binder.getLoc(), rewriter.getType( rewriter.getType()), zerosShapeValues); - Value output = rewriter.create( - binder.getLoc(), resultType, zerosShapeList, none, none, none, - none); - - Value batchSize = rewriter.create( - binder.getLoc(), rewriter.getI64IntegerAttr(batch_size)); - auto batchLoop = rewriter.create( - binder.getLoc(), TypeRange({output.getType()}), batchSize, + Value output = + 
Torch::AtenZerosOp::create(rewriter, binder.getLoc(), resultType, + zerosShapeList, none, none, none, none); + + Value batchSize = Torch::ConstantIntOp::create( + rewriter, binder.getLoc(), rewriter.getI64IntegerAttr(batch_size)); + auto batchLoop = Torch::PrimLoopOp::create( + rewriter, binder.getLoc(), TypeRange({output.getType()}), batchSize, loopConditionTrue, ValueRange({output})); { PatternRewriter::InsertionGuard guard(rewriter); @@ -4563,13 +4633,13 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( inputSequenceShape.push_back(inputShape[1]); auto inputSequenceType = rewriter.getType( inputSequenceShape, inputType.getOptionalDtype()); - Value batchPlusOne = rewriter.create( - binder.getLoc(), batchValue, one); - inputSequence = rewriter.create( - binder.getLoc(), inputSequenceType, input, /*dim=*/zero, - batchValue, batchPlusOne, one); - inputSequence = rewriter.create( - binder.getLoc(), + Value batchPlusOne = Torch::AtenAddIntOp::create( + rewriter, binder.getLoc(), batchValue, one); + inputSequence = Torch::AtenSliceTensorOp::create( + rewriter, binder.getLoc(), inputSequenceType, input, + /*dim=*/zero, batchValue, batchPlusOne, one); + inputSequence = Torch::AtenSqueezeDimOp::create( + rewriter, binder.getLoc(), Torch::ValueTensorType::get(binder.op->getContext(), ArrayRef{inputShape[1]}, inputType.getOptionalDtype()), @@ -4580,11 +4650,11 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( outputForBatchShape.push_back(resultShape[1]); auto outputForBatchType = rewriter.getType( outputForBatchShape, resultType.getOptionalDtype()); - outputForBatch = rewriter.create( - binder.getLoc(), outputForBatchType, output, + outputForBatch = Torch::AtenSliceTensorOp::create( + rewriter, binder.getLoc(), outputForBatchType, output, /*dim=*/zero, batchValue, batchPlusOne, one); - outputForBatch = rewriter.create( - binder.getLoc(), + outputForBatch = Torch::AtenSqueezeDimOp::create( + rewriter, binder.getLoc(), 
Torch::ValueTensorType::get(binder.op->getContext(), ArrayRef{resultShape[1]}, resultType.getOptionalDtype()), @@ -4609,27 +4679,28 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( continue; } - Value ngramLength = rewriter.create( - binder.getLoc(), rewriter.getI64IntegerAttr(ngram_length)); + Value ngramLength = Torch::ConstantIntOp::create( + rewriter, binder.getLoc(), + rewriter.getI64IntegerAttr(ngram_length)); for (int start = start_idx; start < end_idx; start += ngram_length, ngram_i++) { - Value count = rewriter.create( - binder.getLoc(), rewriter.getI64IntegerAttr(0)); + Value count = Torch::ConstantIntOp::create( + rewriter, binder.getLoc(), rewriter.getI64IntegerAttr(0)); // for 1-grams, there is no skipping (skip = gap between // consecutive values in the n-gram pulled from the input // sequence), so we default to skip_count_bound = 1 in that case // to avoid repeating the same count multiple times. int skip_count_bound = (ngram_length == 1) ? 1 : (max_skip_count + 1); - Value skipCountBound = rewriter.create( - binder.getLoc(), intType, + Value skipCountBound = Torch::ConstantIntOp::create( + rewriter, binder.getLoc(), intType, rewriter.getI64IntegerAttr(skip_count_bound)); // given a n-gram to search for, and the input sequence to search // in, we need to count how many times that n-gram appears in the // input for each skip between 0 and max_skip_count (inclusive). 
- auto skipLoop = rewriter.create( - binder.getLoc(), TypeRange({count.getType()}), skipCountBound, - loopConditionTrue, ValueRange({count})); + auto skipLoop = Torch::PrimLoopOp::create( + rewriter, binder.getLoc(), TypeRange({count.getType()}), + skipCountBound, loopConditionTrue, ValueRange({count})); { PatternRewriter::InsertionGuard guard(rewriter); Block *skipLoopBody = rewriter.createBlock( @@ -4637,30 +4708,30 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( TypeRange({loopIndexType, count.getType()}), {binder.getLoc(), binder.getLoc()}); Value skipCount = skipLoopBody->getArgument(0); - Value skipCountPlusOne = rewriter.create( - binder.getLoc(), skipCount, one); + Value skipCountPlusOne = Torch::AtenAddIntOp::create( + rewriter, binder.getLoc(), skipCount, one); count = skipLoopBody->getArgument(1); // max_start_index = // inputSizes.back() - ((ngram_length - 1) * (skip_count + 1)); // the index one higher than the last possible start index // without the input ngram going out of bounds - Value seqLen = rewriter.create( - binder.getLoc(), intType, + Value seqLen = Torch::ConstantIntOp::create( + rewriter, binder.getLoc(), intType, rewriter.getI64IntegerAttr(inputSizes.back())); - Value ngramLengthMinusOne = - rewriter.create(binder.getLoc(), - ngramLength, one); - Value ngramSkipLength = rewriter.create( - binder.getLoc(), ngramLengthMinusOne, skipCountPlusOne); - Value maxStartIndex = rewriter.create( - binder.getLoc(), seqLen, ngramSkipLength); + Value ngramLengthMinusOne = Torch::AtenSubIntOp::create( + rewriter, binder.getLoc(), ngramLength, one); + Value ngramSkipLength = Torch::AtenMulIntOp::create( + rewriter, binder.getLoc(), ngramLengthMinusOne, + skipCountPlusOne); + Value maxStartIndex = Torch::AtenSubIntOp::create( + rewriter, binder.getLoc(), seqLen, ngramSkipLength); // This loop will extract each n-gram with the given skip_count // from the input sequence from start input index, and increment // the count if the n-gram matches the 
one gotten from the // pool_int64s - auto countLoop = rewriter.create( - binder.getLoc(), TypeRange({count.getType()}), + auto countLoop = Torch::PrimLoopOp::create( + rewriter, binder.getLoc(), TypeRange({count.getType()}), maxStartIndex, loopConditionTrue, ValueRange({count})); { PatternRewriter::InsertionGuard guard(rewriter); @@ -4681,53 +4752,57 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( inputSequenceType.getWithSizesAndDtype( llvm::ArrayRef(selectSizes), inputSequenceType.getOptionalDtype()); - Value foundNgram = rewriter.create( - binder.getLoc(), rewriter.getI64IntegerAttr(1)); + Value foundNgram = Torch::ConstantIntOp::create( + rewriter, binder.getLoc(), rewriter.getI64IntegerAttr(1)); for (int i = 0; i < ngram_length; i++) { - Value selectIndex = rewriter.create( - binder.getLoc(), rewriter.getType(), + Value selectIndex = Torch::ConstantIntOp::create( + rewriter, binder.getLoc(), + rewriter.getType(), rewriter.getIntegerAttr(rewriter.getIntegerType(64), i)); - selectIndex = rewriter.create( - binder.getLoc(), selectIndex, skipCountPlusOne); - selectIndex = rewriter.create( - binder.getLoc(), selectIndex, startInputIdx); - Value inputExtract = - rewriter.create( - binder.getLoc(), selectResultType, inputSequence, - zero, selectIndex); - Value inputNgram_i = rewriter.create( - binder.getLoc(), rewriter.getType(), - inputExtract); - - Value poolNgram_i = rewriter.create( - binder.getLoc(), + selectIndex = Torch::AtenMulIntOp::create( + rewriter, binder.getLoc(), selectIndex, + skipCountPlusOne); + selectIndex = Torch::AtenAddIntOp::create( + rewriter, binder.getLoc(), selectIndex, startInputIdx); + Value inputExtract = Torch::AtenSelectIntOp::create( + rewriter, binder.getLoc(), selectResultType, + inputSequence, zero, selectIndex); + Value inputNgram_i = Torch::AtenItemOp::create( + rewriter, binder.getLoc(), + rewriter.getType(), inputExtract); + + Value poolNgram_i = Torch::ConstantIntOp::create( + rewriter, binder.getLoc(), 
rewriter.getI64IntegerAttr(pool_int64s[start + i])); - Value isEqual = rewriter.create( - binder.getLoc(), inputNgram_i, poolNgram_i); - isEqual = rewriter.create( - binder.getLoc(), isEqual); - foundNgram = rewriter.create( - binder.getLoc(), isEqual, foundNgram); + Value isEqual = Torch::AtenEqIntOp::create( + rewriter, binder.getLoc(), inputNgram_i, poolNgram_i); + isEqual = Torch::AtenIntBoolOp::create( + rewriter, binder.getLoc(), isEqual); + foundNgram = Torch::AtenMulIntOp::create( + rewriter, binder.getLoc(), isEqual, foundNgram); } - count = rewriter.create( - binder.getLoc(), count, foundNgram); - rewriter.create( - binder.getLoc(), loopConditionTrue, ValueRange({count})); + count = Torch::AtenAddIntOp::create(rewriter, binder.getLoc(), + count, foundNgram); + Torch::PrimLoopConditionOp::create(rewriter, binder.getLoc(), + loopConditionTrue, + ValueRange({count})); } count = countLoop.getResult(0); - rewriter.create( - binder.getLoc(), loopConditionTrue, ValueRange({count})); + Torch::PrimLoopConditionOp::create(rewriter, binder.getLoc(), + loopConditionTrue, + ValueRange({count})); } count = skipLoop.getResult(0); - Value countFloat = rewriter.create( - binder.getLoc(), count); + Value countFloat = Torch::AtenFloatScalarOp::create( + rewriter, binder.getLoc(), count); if (mode == "IDF" || mode == "TFIDF") { // both IDF and TFIDF modes use weights float weight = weights[ngram_i]; - Value constWeight = rewriter.create( - binder.getLoc(), rewriter.getF64FloatAttr(weight)); + Value constWeight = Torch::ConstantFloatOp::create( + rewriter, binder.getLoc(), + rewriter.getF64FloatAttr(weight)); // TFIDF Value multiplier = countFloat; @@ -4736,64 +4811,66 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( // and the i-th element in weights would be used to scale // (by multiplication) the count of the i-th n-gram in pool. 
- Value intCount = rewriter.create( - binder.getLoc(), count); + Value intCount = Torch::AtenIntScalarOp::create( + rewriter, binder.getLoc(), count); // compare intCount > 0 - Value gtZeroCount = rewriter.create( - binder.getLoc(), intCount, zero); - gtZeroCount = rewriter.create( - binder.getLoc(), gtZeroCount); - Value gtZeroCountFloat = - rewriter.create(binder.getLoc(), - gtZeroCount); + Value gtZeroCount = Torch::AtenGtIntOp::create( + rewriter, binder.getLoc(), intCount, zero); + gtZeroCount = Torch::AtenIntBoolOp::create( + rewriter, binder.getLoc(), gtZeroCount); + Value gtZeroCountFloat = Torch::AtenFloatScalarOp::create( + rewriter, binder.getLoc(), gtZeroCount); multiplier = gtZeroCountFloat; } - countFloat = rewriter.create( - binder.getLoc(), multiplier, constWeight); + countFloat = Torch::AtenMulFloatOp::create( + rewriter, binder.getLoc(), multiplier, constWeight); } - Value dataList = rewriter.create( - binder.getLoc(), + Value dataList = Torch::PrimListConstructOp::create( + rewriter, binder.getLoc(), rewriter.getType( rewriter.getType()), SmallVector{countFloat}); - Value cstDtype = rewriter.create( - binder.getLoc(), rewriter.getI64IntegerAttr( - (int)torch_upstream::ScalarType::Float)); + Value cstDtype = Torch::ConstantIntOp::create( + rewriter, binder.getLoc(), + rewriter.getI64IntegerAttr( + (int)torch_upstream::ScalarType::Float)); SmallVector countShape{1}; auto countType = rewriter.getType( countShape, resultType.getOptionalDtype()); - Value countTensor = rewriter.create( - binder.getLoc(), countType, dataList, /*dtype=*/cstDtype, + Value countTensor = Torch::AtenTensorOp::create( + rewriter, binder.getLoc(), countType, dataList, + /*dtype=*/cstDtype, /*layout=*/none, /*requires_grad=*/cstFalse); - Value insertStart = rewriter.create( - binder.getLoc(), + Value insertStart = Torch::ConstantIntOp::create( + rewriter, binder.getLoc(), rewriter.getI64IntegerAttr(ngram_indexes[ngram_i])); - Value insertEnd = rewriter.create( - 
binder.getLoc(), insertStart, one); - outputForBatch = rewriter.create( - binder.getLoc(), outputForBatch.getType(), outputForBatch, - countTensor, + Value insertEnd = Torch::AtenAddIntOp::create( + rewriter, binder.getLoc(), insertStart, one); + outputForBatch = Torch::AtenSliceScatterOp::create( + rewriter, binder.getLoc(), outputForBatch.getType(), + outputForBatch, countTensor, /*dim=*/zero, insertStart, insertEnd, /*step=*/one); } // start } if (is_2d) { - Value batchPlusOne = rewriter.create( - binder.getLoc(), batchValue, one); - outputForBatch = rewriter.create( - binder.getLoc(), + Value batchPlusOne = Torch::AtenAddIntOp::create( + rewriter, binder.getLoc(), batchValue, one); + outputForBatch = Torch::AtenUnsqueezeOp::create( + rewriter, binder.getLoc(), rewriter.getType( llvm::SmallVector{1, resultShape[1]}, resultType.getDtype()), outputForBatch, zero); - output = rewriter.create( - binder.getLoc(), resultType, output, outputForBatch, + output = Torch::AtenSliceScatterOp::create( + rewriter, binder.getLoc(), resultType, output, outputForBatch, /*dim=*/zero, batchValue, batchPlusOne, /*step=*/one); } else { output = outputForBatch; } - rewriter.create( - binder.getLoc(), loopConditionTrue, ValueRange({output})); + Torch::PrimLoopConditionOp::create(rewriter, binder.getLoc(), + loopConditionTrue, + ValueRange({output})); } output = batchLoop.getResult(0); rewriter.replaceOp(binder.op, output); @@ -4830,10 +4907,10 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( "Expects at least one scan input"); } - Value constZero = rewriter.create( - loc, rewriter.getI64IntegerAttr(0)); - Value constOne = rewriter.create( - loc, rewriter.getI64IntegerAttr(1)); + Value constZero = Torch::ConstantIntOp::create( + rewriter, loc, rewriter.getI64IntegerAttr(0)); + Value constOne = Torch::ConstantIntOp::create( + rewriter, loc, rewriter.getI64IntegerAttr(1)); SmallVector scanOutTypes; for (unsigned i = numInits; i < resultTypes.size(); i++) { auto scanOutTy = 
cast(resultTypes[i]); @@ -4844,12 +4921,12 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( scanOutTypes.push_back(resultTypes[i]); } // Create torch.prim.Loop op. - Value maxTripCount = rewriter.create( - loc, scanInputs[0], constZero); - auto constBoolTrue = rewriter.create( - binder.getLoc(), rewriter.getBoolAttr(true)); - auto primLoop = rewriter.create( - loc, resultTypes, maxTripCount, constBoolTrue, initVals); + Value maxTripCount = Torch::AtenSizeIntOp::create( + rewriter, loc, scanInputs[0], constZero); + auto constBoolTrue = Torch::ConstantBoolOp::create( + rewriter, binder.getLoc(), rewriter.getBoolAttr(true)); + auto primLoop = Torch::PrimLoopOp::create( + rewriter, loc, resultTypes, maxTripCount, constBoolTrue, initVals); rewriter.cloneRegionBefore(*loopBodyIn, primLoop.getRegion(), primLoop.getRegion().begin()); @@ -4865,8 +4942,9 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( for (unsigned i = 0; i < numScanInputs; i++) { auto loopBlockArg = primLoop.getRegion().getArgument(numInits + 1 + i); - Value extract = rewriter.create( - loc, loopBlockArg.getType(), scanInputs[i], constZero, loopInd); + Value extract = Torch::AtenSelectIntOp::create( + rewriter, loc, loopBlockArg.getType(), scanInputs[i], constZero, + loopInd); loopBlockArg.replaceAllUsesWith(extract); } primLoop.getRegion().front().eraseArguments(numInits + 1, @@ -4892,8 +4970,8 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( rewriter, binder.op, scanOutSlices[i], constZero); if (failed(src)) return failure(); - Value scanOut = rewriter.create( - loc, scanOutTypes[i], self, src.value(), constZero, + Value scanOut = Torch::AtenSliceScatterOp::create( + rewriter, loc, scanOutTypes[i], self, src.value(), constZero, /*start=*/loopInd, /*end=*/loopInd, constOne); resTerminatorOperands.push_back(scanOut); diff --git a/lib/Conversion/TorchOnnxToTorch/OnnxRecurrentLayerOpExpanders.cpp b/lib/Conversion/TorchOnnxToTorch/OnnxRecurrentLayerOpExpanders.cpp index 
472eaa15984c..27e72c842b8a 100644 --- a/lib/Conversion/TorchOnnxToTorch/OnnxRecurrentLayerOpExpanders.cpp +++ b/lib/Conversion/TorchOnnxToTorch/OnnxRecurrentLayerOpExpanders.cpp @@ -26,10 +26,10 @@ Value getDirection(ImplicitLocOpBuilder b, int64_t direction, Value input) { llvm::SmallVector{inputType.getSizes().drop_front()}, inputType.getDtype())); auto intType = b.getType(); - Value selectDim = b.create(intType, b.getI64IntegerAttr(0)); + Value selectDim = ConstantIntOp::create(b, intType, b.getI64IntegerAttr(0)); Value cstDirection = - b.create(intType, b.getI64IntegerAttr(direction)); - return b.create(outputType, input, selectDim, cstDirection); + ConstantIntOp::create(b, intType, b.getI64IntegerAttr(direction)); + return AtenSelectIntOp::create(b, outputType, input, selectDim, cstDirection); } struct RnnWeights { @@ -48,11 +48,11 @@ Value rnn_cell(ImplicitLocOpBuilder &b, Value Xt, Value H_prev, auto hTy = cast(H_prev.getType()); auto intType = b.getType(); - Value cstOne = b.create(intType, b.getI64IntegerAttr(1)); + Value cstOne = ConstantIntOp::create(b, intType, b.getI64IntegerAttr(1)); - Value i_x = b.create(hTy, Xt, weights.Wi, weights.Wbi); - Value i_h = b.create(hTy, H_prev, weights.Ri, weights.Rbi); - Value i = b.create(hTy, i_x, i_h, cstOne); + Value i_x = AtenLinearOp::create(b, hTy, Xt, weights.Wi, weights.Wbi); + Value i_h = AtenLinearOp::create(b, hTy, H_prev, weights.Ri, weights.Rbi); + Value i = AtenAddTensorOp::create(b, hTy, i_x, i_h, cstOne); Value H_new = createActivationByName(b, activations.f, i); return H_new; @@ -76,39 +76,39 @@ RnnLayerOutput rnn_layer(ImplicitLocOpBuilder &b, Value X, Value initial_h, auto intType = b.getType(); - Value cstNone = b.create(); - Value cstZero = b.create(intType, b.getI64IntegerAttr(0)); - Value cstOne = b.create(intType, b.getI64IntegerAttr(1)); + Value cstNone = ConstantNoneOp::create(b); + Value cstZero = ConstantIntOp::create(b, intType, b.getI64IntegerAttr(0)); + Value cstOne = 
ConstantIntOp::create(b, intType, b.getI64IntegerAttr(1)); Value cstSeqLen = - b.create(intType, b.getI64IntegerAttr(seq_len)); + ConstantIntOp::create(b, intType, b.getI64IntegerAttr(seq_len)); Value cstBatchSize = - b.create(intType, b.getI64IntegerAttr(batch_size)); + ConstantIntOp::create(b, intType, b.getI64IntegerAttr(batch_size)); Value cstHiddenSize = - b.create(intType, b.getI64IntegerAttr(hidden_size)); + ConstantIntOp::create(b, intType, b.getI64IntegerAttr(hidden_size)); auto yTy = b.getType( SmallVector{seq_len, batch_size, hidden_size}, hTy.getDtype()); - auto YShapeList = b.create( - b.getType(intType), + auto YShapeList = PrimListConstructOp::create( + b, b.getType(intType), ValueRange({cstSeqLen, cstBatchSize, cstHiddenSize})); int64_t hDtypeInt = static_cast(getScalarTypeForType(hTy.getDtype())); Value hDtypeIntVal = - b.create(loc, b.getI64IntegerAttr(hDtypeInt)); + ConstantIntOp::create(b, loc, b.getI64IntegerAttr(hDtypeInt)); - Value Y_initial = b.create(yTy, YShapeList, hDtypeIntVal, - cstNone, cstNone, cstNone); + Value Y_initial = AtenZerosOp::create(b, yTy, YShapeList, hDtypeIntVal, + cstNone, cstNone, cstNone); Value maxTripCount = - b.create(intType, b.getI64IntegerAttr(seq_len)); - Value loopConditionTrue = b.create(true); + ConstantIntOp::create(b, intType, b.getI64IntegerAttr(seq_len)); + Value loopConditionTrue = ConstantBoolOp::create(b, true); Type loopIndexType = intType; - auto loop = b.create(TypeRange({yTy, hTy}), maxTripCount, - loopConditionTrue, - ValueRange({Y_initial, initial_h})); + auto loop = + PrimLoopOp::create(b, TypeRange({yTy, hTy}), maxTripCount, + loopConditionTrue, ValueRange({Y_initial, initial_h})); { OpBuilder::InsertionGuard guard(b); Block *loopBody = @@ -129,22 +129,22 @@ RnnLayerOutput rnn_layer(ImplicitLocOpBuilder &b, Value X, Value initial_h, auto XtType = b.getType( llvm::SmallVector{batch_size, input_size}, xTy.getDtype()); - Value Xt = b.create(XtType, X, cstZero, loopIndex); + Value Xt = 
AtenSelectIntOp::create(b, XtType, X, cstZero, loopIndex); Value H_new = rnn_cell(b, Xt, H_prev, weights, activations); Type hTyUnsqueezed = b.getType( llvm::SmallVector{1, batch_size, hidden_size}, hTy.getDtype()); Value H_new_unsqueezed = - b.create(hTyUnsqueezed, H_new, cstZero); + AtenUnsqueezeOp::create(b, hTyUnsqueezed, H_new, cstZero); - auto loopIndexPlusOne = b.create(intType, loopIndex, cstOne); + auto loopIndexPlusOne = AtenAddIntOp::create(b, intType, loopIndex, cstOne); Value Y_new = - b.create(yTy, Y_prev, H_new_unsqueezed, cstZero, - loopIndex, loopIndexPlusOne, cstOne); + AtenSliceScatterOp::create(b, yTy, Y_prev, H_new_unsqueezed, cstZero, + loopIndex, loopIndexPlusOne, cstOne); - b.create(loopConditionTrue, - ValueRange({Y_new, H_new})); + PrimLoopConditionOp::create(b, loopConditionTrue, + ValueRange({Y_new, H_new})); } RnnLayerOutput output; output.Y = loop.getResult(0); @@ -161,10 +161,10 @@ static Value StaticTranspose(ImplicitLocOpBuilder b, Value value, int64_t dim0, valueTy = b.getType(valueShape, valueTy.getDtype()); auto intType = b.getType(); - Value dim0v = b.create(intType, b.getI64IntegerAttr(dim0)); - Value dim1v = b.create(intType, b.getI64IntegerAttr(dim1)); + Value dim0v = ConstantIntOp::create(b, intType, b.getI64IntegerAttr(dim0)); + Value dim1v = ConstantIntOp::create(b, intType, b.getI64IntegerAttr(dim1)); - return b.create(valueTy, value, dim0v, dim1v); + return AtenTransposeIntOp::create(b, valueTy, value, dim0v, dim1v); } LogicalResult OnnxRnnExpander(OpBinder binder, @@ -173,9 +173,9 @@ LogicalResult OnnxRnnExpander(OpBinder binder, mlir::ImplicitLocOpBuilder b(loc, rewriter); auto intType = b.getType(); - Value cstNone = b.create(); - Value cstZero = b.create(intType, b.getI64IntegerAttr(0)); - Value cstOne = b.create(intType, b.getI64IntegerAttr(1)); + Value cstNone = ConstantNoneOp::create(b); + Value cstZero = ConstantIntOp::create(b, intType, b.getI64IntegerAttr(0)); + Value cstOne = ConstantIntOp::create(b, intType, 
b.getI64IntegerAttr(1)); int64_t num_directions = Torch::kUnknownSize; int64_t hidden_size = Torch::kUnknownSize; @@ -346,28 +346,29 @@ LogicalResult OnnxRnnExpander(OpBinder binder, if (B == nullptr) { SmallVector BShape = {num_directions, 2 * hidden_size}; SmallVector BShapeListContents = { - b.create(intType, b.getI64IntegerAttr(num_directions)), - b.create(intType, b.getI64IntegerAttr(2 * hidden_size))}; - Value BShapeList = b.create( - b.getType(intType), BShapeListContents); + ConstantIntOp::create(b, intType, b.getI64IntegerAttr(num_directions)), + ConstantIntOp::create(b, intType, + b.getI64IntegerAttr(2 * hidden_size))}; + Value BShapeList = PrimListConstructOp::create( + b, b.getType(intType), BShapeListContents); auto BType = b.getType(BShape, wTy.getDtype()); - B = b.create(BType, BShapeList, cstXDtype, cstNone, - cstNone, cstNone); + B = Torch::AtenZerosOp::create(b, BType, BShapeList, cstXDtype, cstNone, + cstNone, cstNone); } if (initial_h == nullptr) { SmallVector initial_h_shape = {num_directions, batch_size, hidden_size}; SmallVector initial_h_shape_list_contents = { - b.create(intType, b.getI64IntegerAttr(num_directions)), - b.create(intType, b.getI64IntegerAttr(batch_size)), - b.create(intType, b.getI64IntegerAttr(hidden_size))}; - Value initial_h_shape_list = b.create( - b.getType(intType), initial_h_shape_list_contents); + ConstantIntOp::create(b, intType, b.getI64IntegerAttr(num_directions)), + ConstantIntOp::create(b, intType, b.getI64IntegerAttr(batch_size)), + ConstantIntOp::create(b, intType, b.getI64IntegerAttr(hidden_size))}; + Value initial_h_shape_list = PrimListConstructOp::create( + b, b.getType(intType), initial_h_shape_list_contents); auto initial_h_type = b.getType(initial_h_shape, wTy.getDtype()); initial_h = - b.create(initial_h_type, initial_h_shape_list, - cstXDtype, cstNone, cstNone, cstNone); + Torch::AtenZerosOp::create(b, initial_h_type, initial_h_shape_list, + cstXDtype, cstNone, cstNone, cstNone); } Value W_forward = 
getDirection(b, 0, W); @@ -376,22 +377,24 @@ LogicalResult OnnxRnnExpander(OpBinder binder, Value initial_h_forward = getDirection(b, 0, initial_h); Value cstHiddenSize = - b.create(intType, b.getI64IntegerAttr(hidden_size)); + ConstantIntOp::create(b, intType, b.getI64IntegerAttr(hidden_size)); RnnWeights weights; weights.Wi = W_forward; weights.Ri = R_forward; - weights.Wbi = b.create( + weights.Wbi = AtenSliceTensorOp::create( + b, b.getType(llvm::SmallVector{hidden_size}, wTy.getDtype()), B_forward, cstZero, cstZero, cstHiddenSize, cstOne); - weights.Rbi = b.create( + weights.Rbi = AtenSliceTensorOp::create( + b, b.getType(llvm::SmallVector{hidden_size}, wTy.getDtype()), B_forward, cstZero, cstHiddenSize, - b.create( - cstHiddenSize, - b.create(intType, b.getI64IntegerAttr(2))), + AtenMulIntOp::create( + b, cstHiddenSize, + ConstantIntOp::create(b, intType, b.getI64IntegerAttr(2))), cstOne); RnnLayerOutput rnnLayerOutput = @@ -400,15 +403,15 @@ LogicalResult OnnxRnnExpander(OpBinder binder, auto Y_h_unsqueezed_type = b.getType( llvm::SmallVector{num_directions, batch_size, hidden_size}, cast(rnnLayerOutput.Y_h.getType()).getDtype()); - Value Y_h_unsqueezed = b.create(Y_h_unsqueezed_type, - rnnLayerOutput.Y_h, cstZero); + Value Y_h_unsqueezed = AtenUnsqueezeOp::create(b, Y_h_unsqueezed_type, + rnnLayerOutput.Y_h, cstZero); auto Y_unsqueezed_type = b.getType( llvm::SmallVector{seq_len, num_directions, batch_size, hidden_size}, cast(rnnLayerOutput.Y_h.getType()).getDtype()); Value Y_unsqueezed = - b.create(Y_unsqueezed_type, rnnLayerOutput.Y, cstOne); + AtenUnsqueezeOp::create(b, Y_unsqueezed_type, rnnLayerOutput.Y, cstOne); if (layout == 1) { Y_h_unsqueezed = StaticTranspose(b, Y_h_unsqueezed, 0, 1); @@ -466,37 +469,37 @@ LstmCellState lstm_cell(ImplicitLocOpBuilder &b, Value Xt, Value H_prev, auto intType = b.getType(); auto hTy = cast(H_prev.getType()); - Value cstOne = b.create(intType, b.getI64IntegerAttr(1)); + Value cstOne = ConstantIntOp::create(b, 
intType, b.getI64IntegerAttr(1)); // Apply linear/matmul for each gate separately // names are consistent with ONNX LSTM documentation - Value i_x = b.create(hTy, Xt, weights.W_i, weights.Wb_i); - Value i_h = b.create(hTy, H_prev, weights.R_i, weights.Rb_i); - Value i = b.create(hTy, i_x, i_h, cstOne); + Value i_x = AtenLinearOp::create(b, hTy, Xt, weights.W_i, weights.Wb_i); + Value i_h = AtenLinearOp::create(b, hTy, H_prev, weights.R_i, weights.Rb_i); + Value i = AtenAddTensorOp::create(b, hTy, i_x, i_h, cstOne); Value i_act = createActivationByName(b, activations.f, i); - Value o_x = b.create(hTy, Xt, weights.W_o, weights.Wb_o); - Value o_h = b.create(hTy, H_prev, weights.R_o, weights.Rb_o); - Value o = b.create(hTy, o_x, o_h, cstOne); + Value o_x = AtenLinearOp::create(b, hTy, Xt, weights.W_o, weights.Wb_o); + Value o_h = AtenLinearOp::create(b, hTy, H_prev, weights.R_o, weights.Rb_o); + Value o = AtenAddTensorOp::create(b, hTy, o_x, o_h, cstOne); Value o_act = createActivationByName(b, activations.f, o); - Value f_x = b.create(hTy, Xt, weights.W_f, weights.Wb_f); - Value f_h = b.create(hTy, H_prev, weights.R_f, weights.Rb_f); - Value f = b.create(hTy, f_x, f_h, cstOne); + Value f_x = AtenLinearOp::create(b, hTy, Xt, weights.W_f, weights.Wb_f); + Value f_h = AtenLinearOp::create(b, hTy, H_prev, weights.R_f, weights.Rb_f); + Value f = AtenAddTensorOp::create(b, hTy, f_x, f_h, cstOne); Value f_act = createActivationByName(b, activations.f, f); - Value ct_x = b.create(hTy, Xt, weights.W_c, weights.Wb_c); - Value ct_h = b.create(hTy, H_prev, weights.R_c, weights.Rb_c); - Value ct = b.create(hTy, ct_x, ct_h, cstOne); + Value ct_x = AtenLinearOp::create(b, hTy, Xt, weights.W_c, weights.Wb_c); + Value ct_h = AtenLinearOp::create(b, hTy, H_prev, weights.R_c, weights.Rb_c); + Value ct = AtenAddTensorOp::create(b, hTy, ct_x, ct_h, cstOne); Value ct_act = createActivationByName(b, activations.g, ct); - Value C_forget = b.create(hTy, f_act, C_prev); - Value C_input = 
b.create(hTy, i_act, ct_act); + Value C_forget = AtenMulTensorOp::create(b, hTy, f_act, C_prev); + Value C_input = AtenMulTensorOp::create(b, hTy, i_act, ct_act); LstmCellState newCellState; - newCellState.C = b.create(hTy, C_forget, C_input, cstOne); + newCellState.C = AtenAddTensorOp::create(b, hTy, C_forget, C_input, cstOne); Value C_new_act = createActivationByName(b, activations.h, newCellState.C); - newCellState.H = b.create(hTy, o_act, C_new_act); + newCellState.H = AtenMulTensorOp::create(b, hTy, o_act, C_new_act); return newCellState; } @@ -533,34 +536,34 @@ LstmLayerOutput lstm_layer(ConversionPatternRewriter &rewriter, Location &loc, auto intType = b.getType(); - Value cstNone = b.create(); - Value cstZero = b.create(intType, b.getI64IntegerAttr(0)); - Value cstOne = b.create(intType, b.getI64IntegerAttr(1)); + Value cstNone = ConstantNoneOp::create(b); + Value cstZero = ConstantIntOp::create(b, intType, b.getI64IntegerAttr(0)); + Value cstOne = ConstantIntOp::create(b, intType, b.getI64IntegerAttr(1)); Value cstHiddenSize = - b.create(intType, b.getI64IntegerAttr(hidden_size)); + ConstantIntOp::create(b, intType, b.getI64IntegerAttr(hidden_size)); auto yTy = getTensorTypeFromShapeValues({seq_len, batch_size, cstHiddenSize}, hTy.getDtype()); - auto YShapeList = b.create( - b.getType(intType), + auto YShapeList = PrimListConstructOp::create( + b, b.getType(intType), ValueRange({seq_len, batch_size, cstHiddenSize})); int64_t hDtypeInt = static_cast(getScalarTypeForType(hTy.getDtype())); Value hDtypeIntVal = - b.create(loc, b.getI64IntegerAttr(hDtypeInt)); + ConstantIntOp::create(b, loc, b.getI64IntegerAttr(hDtypeInt)); - Value Y_initial = b.create(yTy, YShapeList, hDtypeIntVal, - cstNone, cstNone, cstNone); + Value Y_initial = AtenZerosOp::create(b, yTy, YShapeList, hDtypeIntVal, + cstNone, cstNone, cstNone); // Create a for-like PrimLoopOp. 
Value maxTripCount = seq_len; - Value loopConditionTrue = b.create(true); + Value loopConditionTrue = ConstantBoolOp::create(b, true); Type loopIndexType = intType; - auto loop = b.create( - TypeRange({yTy, hTy, cTy}), maxTripCount, loopConditionTrue, - ValueRange({Y_initial, initial_h, initial_c})); + auto loop = PrimLoopOp::create(b, TypeRange({yTy, hTy, cTy}), maxTripCount, + loopConditionTrue, + ValueRange({Y_initial, initial_h, initial_c})); { OpBuilder::InsertionGuard guard(b); Block *loopBody = @@ -583,7 +586,7 @@ LstmLayerOutput lstm_layer(ConversionPatternRewriter &rewriter, Location &loc, auto XtType = getTensorTypeFromShapeValues({batch_size, input_size}, xTy.getDtype()); - Value Xt = b.create(XtType, X, cstZero, loopIndex); + Value Xt = AtenSelectIntOp::create(b, XtType, X, cstZero, loopIndex); auto [H_new, C_new] = lstm_cell(b, Xt, H_prev, C_prev, weights, activations); @@ -591,15 +594,15 @@ LstmLayerOutput lstm_layer(ConversionPatternRewriter &rewriter, Location &loc, auto hTyUnsqueezed = getTensorTypeFromShapeValues( {cstOne, batch_size, cstHiddenSize}, hTy.getDtype()); Value H_new_unsqueezed = - b.create(hTyUnsqueezed, H_new, cstZero); + AtenUnsqueezeOp::create(b, hTyUnsqueezed, H_new, cstZero); - auto loopIndexPlusOne = b.create(intType, loopIndex, cstOne); + auto loopIndexPlusOne = AtenAddIntOp::create(b, intType, loopIndex, cstOne); Value Y_new = - b.create(yTy, Y_prev, H_new_unsqueezed, cstZero, - loopIndex, loopIndexPlusOne, cstOne); + AtenSliceScatterOp::create(b, yTy, Y_prev, H_new_unsqueezed, cstZero, + loopIndex, loopIndexPlusOne, cstOne); - b.create(loopConditionTrue, - ValueRange({Y_new, H_new, C_new})); + PrimLoopConditionOp::create(b, loopConditionTrue, + ValueRange({Y_new, H_new, C_new})); } LstmLayerOutput output; output.Y = loop.getResult(0); @@ -711,18 +714,18 @@ LogicalResult OnnxLstmExpander(OpBinder binder, Value B; if (binder.tensorOperandAtIndex(B, 3)) { - Value none = b.create(); - Value cstHiddenx8 = b.create( - b.getType(), 
b.getI64IntegerAttr(8 * hidden_size)); - Value cstNumDir = b.create( - b.getType(), b.getI64IntegerAttr(num_directions)); + Value none = ConstantNoneOp::create(b); + Value cstHiddenx8 = ConstantIntOp::create( + b, b.getType(), b.getI64IntegerAttr(8 * hidden_size)); + Value cstNumDir = ConstantIntOp::create( + b, b.getType(), b.getI64IntegerAttr(num_directions)); auto BType = b.getType( llvm::SmallVector{num_directions, 8 * hidden_size}, cast(W.getType()).getDtype()); - Value zerosShapeList = b.create( - b.getType(b.getType()), + Value zerosShapeList = PrimListConstructOp::create( + b, b.getType(b.getType()), SmallVector{cstNumDir, cstHiddenx8}); - B = b.create(BType, zerosShapeList, none, none, none, none); + B = AtenZerosOp::create(b, BType, zerosShapeList, none, none, none, none); } LstmActivations activations, activationsRev; @@ -793,11 +796,12 @@ LogicalResult OnnxLstmExpander(OpBinder binder, } else { Value x_input_size = Torch::getTensorDimSize(rewriter, X, 2); Value w_input_size = - b.create(loc, b.getI64IntegerAttr(wTy.getSizes()[2])); + ConstantIntOp::create(b, loc, b.getI64IntegerAttr(wTy.getSizes()[2])); - auto eq = b.create(loc, x_input_size, w_input_size); - rewriter.create( - loc, eq, rewriter.getStringAttr("The input_size of W must equal X.")); + auto eq = AtenEqIntOp::create(b, loc, x_input_size, w_input_size); + RuntimeAssertOp::create( + rewriter, loc, eq, + rewriter.getStringAttr("The input_size of W must equal X.")); } Value W_forward = getDirection(b, 0, W); @@ -814,17 +818,17 @@ LogicalResult OnnxLstmExpander(OpBinder binder, auto intType = b.getType(); Value cstNumDirections = - b.create(intType, b.getI64IntegerAttr(num_directions)); + ConstantIntOp::create(b, intType, b.getI64IntegerAttr(num_directions)); Value cstHiddenSize = - b.create(intType, b.getI64IntegerAttr(hidden_size)); - Value cstNone = b.create(); - Value cstZero = b.create(intType, b.getI64IntegerAttr(0)); - Value cstOne = b.create(intType, b.getI64IntegerAttr(1)); + 
ConstantIntOp::create(b, intType, b.getI64IntegerAttr(hidden_size)); + Value cstNone = ConstantNoneOp::create(b); + Value cstZero = ConstantIntOp::create(b, intType, b.getI64IntegerAttr(0)); + Value cstOne = ConstantIntOp::create(b, intType, b.getI64IntegerAttr(1)); auto hTy = getTensorTypeFromShapeValues( {cstNumDirections, batchSize, cstHiddenSize}, xTy.getDtype()); - Value hShape = b.create( - b.getType(intType), + Value hShape = PrimListConstructOp::create( + b, b.getType(intType), ValueRange({cstNumDirections, batchSize, cstHiddenSize})); Value cstDtype = getDtypeIntValueForType(rewriter, loc, xTy.getDtype()); @@ -832,8 +836,8 @@ LogicalResult OnnxLstmExpander(OpBinder binder, Value initial_h; if (binder.tensorOperandAtIndex(initial_h, 5)) { // default created for layout 0 - initial_h = - b.create(hTy, hShape, cstDtype, cstNone, cstNone, cstNone); + initial_h = AtenZerosOp::create(b, hTy, hShape, cstDtype, cstNone, cstNone, + cstNone); } else { if (layout == 1) initial_h = StaticTranspose(b, initial_h, 0, 1); @@ -842,8 +846,8 @@ LogicalResult OnnxLstmExpander(OpBinder binder, Value initial_c; if (binder.tensorOperandAtIndex(initial_c, 6)) { // default created for layout 0 - initial_c = - b.create(hTy, hShape, cstDtype, cstNone, cstNone, cstNone); + initial_c = AtenZerosOp::create(b, hTy, hShape, cstDtype, cstNone, cstNone, + cstNone); } else { if (layout == 1) initial_c = StaticTranspose(b, initial_c, 0, 1); @@ -871,7 +875,7 @@ LogicalResult OnnxLstmExpander(OpBinder binder, LstmWeights weights, weightsRev; // weights and biases auto intConst = [&](int64_t val) { - return b.create(intType, b.getI64IntegerAttr(val)); + return ConstantIntOp::create(b, intType, b.getI64IntegerAttr(val)); }; // split B into Wb and Rb @@ -881,33 +885,33 @@ LogicalResult OnnxLstmExpander(OpBinder binder, auto biasType = b.getType( llvm::SmallVector{hidden_size * 4}, wTy.getDtype()); // forward - Value Wb = b.create(biasType, - /*input=*/B_forward, - /*dim=*/cstZero, - 
/*start=*/cstZero, - /*end=*/inputWeightsEndIdx, - /*step=*/cstOne); - Value Rb = b.create(biasType, - /*input=*/B_forward, - /*dim=*/cstZero, - /*start=*/recurrentWeightsStartIdx, - /*end=*/recurrentWeightsEndIdx, - /*step=*/cstOne); + Value Wb = AtenSliceTensorOp::create(b, biasType, + /*input=*/B_forward, + /*dim=*/cstZero, + /*start=*/cstZero, + /*end=*/inputWeightsEndIdx, + /*step=*/cstOne); + Value Rb = AtenSliceTensorOp::create(b, biasType, + /*input=*/B_forward, + /*dim=*/cstZero, + /*start=*/recurrentWeightsStartIdx, + /*end=*/recurrentWeightsEndIdx, + /*step=*/cstOne); Value Wb_reverse, Rb_reverse; if (isBidirectional) { // reverse - Wb_reverse = b.create(biasType, - /*input=*/B_reverse, - /*dim=*/cstZero, - /*start=*/cstZero, - /*end=*/inputWeightsEndIdx, - /*step=*/cstOne); - Rb_reverse = b.create(biasType, - /*input=*/B_reverse, - /*dim=*/cstZero, - /*start=*/recurrentWeightsStartIdx, - /*end=*/recurrentWeightsEndIdx, - /*step=*/cstOne); + Wb_reverse = AtenSliceTensorOp::create(b, biasType, + /*input=*/B_reverse, + /*dim=*/cstZero, + /*start=*/cstZero, + /*end=*/inputWeightsEndIdx, + /*step=*/cstOne); + Rb_reverse = AtenSliceTensorOp::create(b, biasType, + /*input=*/B_reverse, + /*dim=*/cstZero, + /*start=*/recurrentWeightsStartIdx, + /*end=*/recurrentWeightsEndIdx, + /*step=*/cstOne); } // gate splitting @@ -937,8 +941,8 @@ LogicalResult OnnxLstmExpander(OpBinder binder, }; auto sliceGateBias = [&](Value startIdx, Value endIdx, Value WoB) { - return b.create(gateBiasType, WoB, cstZero, startIdx, - endIdx, cstOne); + return AtenSliceTensorOp::create(b, gateBiasType, WoB, cstZero, startIdx, + endIdx, cstOne); }; std::tie(weights.Wb_i, weights.Wb_o, weights.Wb_f, weights.Wb_c) = sliceIOFC(sliceGateBias, Wb); @@ -948,8 +952,8 @@ LogicalResult OnnxLstmExpander(OpBinder binder, weightsRev.Wb_c) = sliceIOFC(sliceGateBias, Wb_reverse); auto sliceGateBiasR = [&](Value startIdx, Value endIdx, Value WoB) { - return b.create(gateBiasType, WoB, cstZero, startIdx, 
- endIdx, cstOne); + return AtenSliceTensorOp::create(b, gateBiasType, WoB, cstZero, startIdx, + endIdx, cstOne); }; std::tie(weights.Rb_i, weights.Rb_o, weights.Rb_f, weights.Rb_c) = sliceIOFC(sliceGateBiasR, Rb); @@ -959,8 +963,8 @@ LogicalResult OnnxLstmExpander(OpBinder binder, weightsRev.Rb_c) = sliceIOFC(sliceGateBiasR, Rb_reverse); auto sliceGateWeightsIH = [&](Value startIdx, Value endIdx, Value WoB) { - return b.create(gateWeightsTypeIH, WoB, cstZero, - startIdx, endIdx, cstOne); + return AtenSliceTensorOp::create(b, gateWeightsTypeIH, WoB, cstZero, + startIdx, endIdx, cstOne); }; std::tie(weights.W_i, weights.W_o, weights.W_f, weights.W_c) = sliceIOFC(sliceGateWeightsIH, W_forward); @@ -970,8 +974,8 @@ LogicalResult OnnxLstmExpander(OpBinder binder, sliceIOFC(sliceGateWeightsIH, W_reverse); auto sliceGateWeightsHH = [&](Value startIdx, Value endIdx, Value WoB) { - return b.create(gateWeightsTypeHH, WoB, cstZero, - startIdx, endIdx, cstOne); + return AtenSliceTensorOp::create(b, gateWeightsTypeHH, WoB, cstZero, + startIdx, endIdx, cstOne); }; std::tie(weights.R_i, weights.R_o, weights.R_f, weights.R_c) = @@ -1002,17 +1006,17 @@ LogicalResult OnnxLstmExpander(OpBinder binder, auto Y_res_type = getTensorTypeFromShapeValues( {seqLen, cstNumDirections, batchSize, cstHiddenSize}, YallDtype); - Value Y_h_forward = - b.create(Y_h_Y_c_uni_type, lstmLayerOutput.Y_h, cstZero); + Value Y_h_forward = AtenUnsqueezeOp::create(b, Y_h_Y_c_uni_type, + lstmLayerOutput.Y_h, cstZero); - Value Y_c_forward = - b.create(Y_h_Y_c_uni_type, lstmLayerOutput.Y_c, cstZero); + Value Y_c_forward = AtenUnsqueezeOp::create(b, Y_h_Y_c_uni_type, + lstmLayerOutput.Y_c, cstZero); // unsqueeze num_directions dim1 of Y // to create the onnx.LSTM output shape [seq_length, num_directions, // batch_size, hidden_size] Value Y_forward = - b.create(Y_uni_type, lstmLayerOutput.Y, cstOne); + AtenUnsqueezeOp::create(b, Y_uni_type, lstmLayerOutput.Y, cstOne); Y_result = Y_forward; Y_h_result = 
Y_h_forward; @@ -1025,42 +1029,42 @@ LogicalResult OnnxLstmExpander(OpBinder binder, Y_reverse, Y_output_list, Y_h_output_list, Y_c_output_list; LstmLayerOutput revLstmLayerOutput; if (isBidirectional) { - dim0 = b.create(b.getType(intType), - SmallVector{cstZero}); - X_reverse = b.create(xTy, X, dim0); // flip along seq_len dim + dim0 = PrimListConstructOp::create(b, b.getType(intType), + SmallVector{cstZero}); + X_reverse = AtenFlipOp::create(b, xTy, X, dim0); // flip along seq_len dim revLstmLayerOutput = lstm_layer(rewriter, loc, X_reverse, initial_h_reverse, initial_c_reverse, weightsRev, activationsRev); // unsqueeze Y_rev, Y_h_rev, Y_c_rev - Y_h_reverse = b.create(Y_h_Y_c_uni_type, - revLstmLayerOutput.Y_h, cstZero); - Y_c_reverse = b.create(Y_h_Y_c_uni_type, - revLstmLayerOutput.Y_c, cstZero); + Y_h_reverse = AtenUnsqueezeOp::create(b, Y_h_Y_c_uni_type, + revLstmLayerOutput.Y_h, cstZero); + Y_c_reverse = AtenUnsqueezeOp::create(b, Y_h_Y_c_uni_type, + revLstmLayerOutput.Y_c, cstZero); Y_reverse_unflipped = - b.create(Y_uni_type, revLstmLayerOutput.Y, cstOne); + AtenUnsqueezeOp::create(b, Y_uni_type, revLstmLayerOutput.Y, cstOne); // flip Y_rev on dim 0 [seq_len] - Y_reverse = b.create(Y_uni_type, Y_reverse_unflipped, dim0); + Y_reverse = AtenFlipOp::create(b, Y_uni_type, Y_reverse_unflipped, dim0); // Concat forward and reverse results on dim 1 Y_output_list = - b.create(b.getType(Y_uni_type), - SmallVector{Y_forward, Y_reverse}); - Y_result = b.create(Y_res_type, Y_output_list, cstOne); + PrimListConstructOp::create(b, b.getType(Y_uni_type), + SmallVector{Y_forward, Y_reverse}); + Y_result = AtenCatOp::create(b, Y_res_type, Y_output_list, cstOne); // Concat forward and reverse results on dim 0 - Y_h_output_list = b.create( - b.getType(Y_h_Y_c_uni_type), + Y_h_output_list = PrimListConstructOp::create( + b, b.getType(Y_h_Y_c_uni_type), SmallVector{Y_h_forward, Y_h_reverse}); Y_h_result = - b.create(Y_h_Y_c_res_type, Y_h_output_list, cstZero); + 
AtenCatOp::create(b, Y_h_Y_c_res_type, Y_h_output_list, cstZero); - Y_c_output_list = b.create( - b.getType(Y_h_Y_c_uni_type), + Y_c_output_list = PrimListConstructOp::create( + b, b.getType(Y_h_Y_c_uni_type), SmallVector{Y_c_forward, Y_c_reverse}); Y_c_result = - b.create(Y_h_Y_c_res_type, Y_c_output_list, cstZero); + AtenCatOp::create(b, Y_h_Y_c_res_type, Y_c_output_list, cstZero); } if (layout == 1) { @@ -1124,46 +1128,46 @@ Value gru_cell(ImplicitLocOpBuilder &b, Value Xt, Value H_prev, auto hTy = cast(H_prev.getType()); auto intType = b.getType(); - Value cstOne = b.create(intType, b.getI64IntegerAttr(1)); + Value cstOne = ConstantIntOp::create(b, intType, b.getI64IntegerAttr(1)); - Value z_w = b.create(hTy, Xt, weights.Wz, weights.Wbz); - Value z_r = b.create(hTy, H_prev, weights.Rz, weights.Rbz); - Value z_pre = b.create(hTy, z_w, z_r, cstOne); + Value z_w = AtenLinearOp::create(b, hTy, Xt, weights.Wz, weights.Wbz); + Value z_r = AtenLinearOp::create(b, hTy, H_prev, weights.Rz, weights.Rbz); + Value z_pre = AtenAddTensorOp::create(b, hTy, z_w, z_r, cstOne); Value zt = createActivationByName(b, activations.f, z_pre); - Value r_w = b.create(hTy, Xt, weights.Wr, weights.Wbr); - Value r_r = b.create(hTy, H_prev, weights.Rr, weights.Rbr); - Value r_pre = b.create(hTy, r_w, r_r, cstOne); + Value r_w = AtenLinearOp::create(b, hTy, Xt, weights.Wr, weights.Wbr); + Value r_r = AtenLinearOp::create(b, hTy, H_prev, weights.Rr, weights.Rbr); + Value r_pre = AtenAddTensorOp::create(b, hTy, r_w, r_r, cstOne); Value rt = createActivationByName(b, activations.f, r_pre); - Value h_w = b.create(hTy, Xt, weights.Wh, weights.Wbh); + Value h_w = AtenLinearOp::create(b, hTy, Xt, weights.Wh, weights.Wbh); Value h_r; if (linear_before_reset) { // when linear_before_reset = 1, multiply r with H_prev to reset // before applying linear layer Value h_linear = - b.create(hTy, H_prev, weights.Rh, weights.Rbh); - h_r = b.create(hTy, h_linear, rt); + AtenLinearOp::create(b, hTy, H_prev, 
weights.Rh, weights.Rbh); + h_r = AtenMulTensorOp::create(b, hTy, h_linear, rt); } else { // otherwise, multiply first and then apply linear layer - Value h_reset = b.create(hTy, H_prev, rt); - h_r = b.create(hTy, h_reset, weights.Rh, weights.Rbh); + Value h_reset = AtenMulTensorOp::create(b, hTy, H_prev, rt); + h_r = AtenLinearOp::create(b, hTy, h_reset, weights.Rh, weights.Rbh); } - Value h_pre = b.create(hTy, h_w, h_r, cstOne); + Value h_pre = AtenAddTensorOp::create(b, hTy, h_w, h_r, cstOne); Value ht = createActivationByName(b, activations.g, h_pre); // Create a constant tensor filled with ones, matching the shape of zt - Value cstNone = b.create(); + Value cstNone = ConstantNoneOp::create(b); int64_t typeInt = (int64_t)getScalarTypeForType(hTy.getDtype()); - Value dtype = b.create(b.getI64IntegerAttr(typeInt)); - Value ones = b.create( - hTy, zt, dtype, /*layout=*/cstNone, + Value dtype = ConstantIntOp::create(b, b.getI64IntegerAttr(typeInt)); + Value ones = Torch::AtenOnesLikeOp::create( + b, hTy, zt, dtype, /*layout=*/cstNone, /*device=*/cstNone, /*pin_memory=*/cstNone, /*memory_format=*/cstNone); - Value one_minus_zt = b.create(hTy, ones, zt, cstOne); - Value ht_scaled = b.create(hTy, one_minus_zt, ht); - Value H_prev_zt = b.create(hTy, H_prev, zt); - Value H_new = b.create(hTy, ht_scaled, H_prev_zt, cstOne); + Value one_minus_zt = AtenSubTensorOp::create(b, hTy, ones, zt, cstOne); + Value ht_scaled = AtenMulTensorOp::create(b, hTy, one_minus_zt, ht); + Value H_prev_zt = AtenMulTensorOp::create(b, hTy, H_prev, zt); + Value H_new = AtenAddTensorOp::create(b, hTy, ht_scaled, H_prev_zt, cstOne); return H_new; } @@ -1187,38 +1191,38 @@ GruLayerOutput gru_layer(ImplicitLocOpBuilder &b, Value X, Value initial_h, auto intType = b.getType(); - Value cstNone = b.create(); - Value cstZero = b.create(intType, b.getI64IntegerAttr(0)); - Value cstOne = b.create(intType, b.getI64IntegerAttr(1)); + Value cstNone = ConstantNoneOp::create(b); + Value cstZero = 
ConstantIntOp::create(b, intType, b.getI64IntegerAttr(0)); + Value cstOne = ConstantIntOp::create(b, intType, b.getI64IntegerAttr(1)); Value cstSeqLen = - b.create(intType, b.getI64IntegerAttr(seq_len)); + ConstantIntOp::create(b, intType, b.getI64IntegerAttr(seq_len)); Value cstBatchSize = - b.create(intType, b.getI64IntegerAttr(batch_size)); + ConstantIntOp::create(b, intType, b.getI64IntegerAttr(batch_size)); Value cstHiddenSize = - b.create(intType, b.getI64IntegerAttr(hidden_size)); + ConstantIntOp::create(b, intType, b.getI64IntegerAttr(hidden_size)); auto yTy = b.getType( SmallVector{seq_len, batch_size, hidden_size}, hTy.getDtype()); - auto YShapeList = b.create( - b.getType(intType), + auto YShapeList = PrimListConstructOp::create( + b, b.getType(intType), ValueRange({cstSeqLen, cstBatchSize, cstHiddenSize})); int64_t hDtypeInt = static_cast(getScalarTypeForType(hTy.getDtype())); - Value hDtypeIntVal = b.create(b.getI64IntegerAttr(hDtypeInt)); + Value hDtypeIntVal = ConstantIntOp::create(b, b.getI64IntegerAttr(hDtypeInt)); - Value Y_initial = b.create(yTy, YShapeList, hDtypeIntVal, - cstNone, cstNone, cstNone); + Value Y_initial = AtenZerosOp::create(b, yTy, YShapeList, hDtypeIntVal, + cstNone, cstNone, cstNone); Value maxTripCount = cstSeqLen; - Value loopConditionTrue = b.create(true); + Value loopConditionTrue = ConstantBoolOp::create(b, true); Type loopIndexType = intType; - auto loop = b.create(TypeRange({yTy, hTy}), maxTripCount, - loopConditionTrue, - ValueRange({Y_initial, initial_h})); + auto loop = + PrimLoopOp::create(b, TypeRange({yTy, hTy}), maxTripCount, + loopConditionTrue, ValueRange({Y_initial, initial_h})); { OpBuilder::InsertionGuard guard(b); @@ -1233,7 +1237,7 @@ GruLayerOutput gru_layer(ImplicitLocOpBuilder &b, Value X, Value initial_h, auto XtType = b.getType( llvm::SmallVector{batch_size, input_size}, xTy.getDtype()); - Value Xt = b.create(XtType, X, cstZero, loopIndex); + Value Xt = AtenSelectIntOp::create(b, XtType, X, cstZero, 
loopIndex); Value H_new = gru_cell(b, Xt, H_prev, weights, activations, linear_before_reset); @@ -1241,15 +1245,15 @@ GruLayerOutput gru_layer(ImplicitLocOpBuilder &b, Value X, Value initial_h, Type hTyUnsqueezed = b.getType( llvm::SmallVector{1, batch_size, hidden_size}, hTy.getDtype()); Value H_new_unsqueezed = - b.create(hTyUnsqueezed, H_new, cstZero); + AtenUnsqueezeOp::create(b, hTyUnsqueezed, H_new, cstZero); - auto loopIndexPlusOne = b.create(intType, loopIndex, cstOne); + auto loopIndexPlusOne = AtenAddIntOp::create(b, intType, loopIndex, cstOne); Value Y_new = - b.create(yTy, Y_prev, H_new_unsqueezed, cstZero, - loopIndex, loopIndexPlusOne, cstOne); + AtenSliceScatterOp::create(b, yTy, Y_prev, H_new_unsqueezed, cstZero, + loopIndex, loopIndexPlusOne, cstOne); - b.create(loopConditionTrue, - ValueRange({Y_new, H_new})); + PrimLoopConditionOp::create(b, loopConditionTrue, + ValueRange({Y_new, H_new})); } GruLayerOutput output; @@ -1265,9 +1269,9 @@ LogicalResult OnnxGruExpander(OpBinder binder, mlir::ImplicitLocOpBuilder b(loc, rewriter); auto intType = b.getType(); - Value cstNone = b.create(); - Value cstZero = b.create(intType, b.getI64IntegerAttr(0)); - Value cstOne = b.create(intType, b.getI64IntegerAttr(1)); + Value cstNone = ConstantNoneOp::create(b); + Value cstZero = ConstantIntOp::create(b, intType, b.getI64IntegerAttr(0)); + Value cstOne = ConstantIntOp::create(b, intType, b.getI64IntegerAttr(1)); // Binding arguments ValueTensorType yTy, Y_hType; @@ -1358,17 +1362,17 @@ LogicalResult OnnxGruExpander(OpBinder binder, if (binder.tensorOperandAtIndex(initial_h, 5)) { Value cstNumDirections = - b.create(intType, b.getI64IntegerAttr(num_directions)); + ConstantIntOp::create(b, intType, b.getI64IntegerAttr(num_directions)); Value cstBatchSize = - b.create(intType, b.getI64IntegerAttr(batch_size)); + ConstantIntOp::create(b, intType, b.getI64IntegerAttr(batch_size)); Value cstHiddenSize = - b.create(intType, b.getI64IntegerAttr(hidden_size)); - Value 
hShape = b.create( - b.getType(intType), + ConstantIntOp::create(b, intType, b.getI64IntegerAttr(hidden_size)); + Value hShape = PrimListConstructOp::create( + b, b.getType(intType), ValueRange({cstNumDirections, cstBatchSize, cstHiddenSize})); Value cstDtype = getDtypeIntValueForType(rewriter, loc, xTy.getDtype()); - initial_h = - b.create(hTy, hShape, cstDtype, cstNone, cstNone, cstNone); + initial_h = AtenZerosOp::create(b, hTy, hShape, cstDtype, cstNone, cstNone, + cstNone); } else { if (layout == 1) { initial_h = StaticTranspose(b, initial_h, 0, 1); @@ -1376,7 +1380,7 @@ LogicalResult OnnxGruExpander(OpBinder binder, } if (binder.tensorOperandAtIndex(sequence_lens, 4)) - sequence_lens = b.create(); + sequence_lens = ConstantNoneOp::create(b); float clip; if (!binder.f32FloatAttr(clip, "clip") && clip != 0.0f) @@ -1394,13 +1398,14 @@ LogicalResult OnnxGruExpander(OpBinder binder, if (B == nullptr) { SmallVector BShape = {num_directions, 6 * hidden_size}; SmallVector BShapeListContents = { - b.create(intType, b.getI64IntegerAttr(num_directions)), - b.create(intType, b.getI64IntegerAttr(6 * hidden_size))}; - Value BShapeList = b.create( - b.getType(intType), BShapeListContents); + ConstantIntOp::create(b, intType, b.getI64IntegerAttr(num_directions)), + ConstantIntOp::create(b, intType, + b.getI64IntegerAttr(6 * hidden_size))}; + Value BShapeList = PrimListConstructOp::create( + b, b.getType(intType), BShapeListContents); auto BType = b.getType(BShape, wTy.getDtype()); - B = b.create(BType, BShapeList, cstXDtype, cstNone, - cstNone, cstNone); + B = Torch::AtenZerosOp::create(b, BType, BShapeList, cstXDtype, cstNone, + cstNone, cstNone); } Value W_forward = getDirection(b, 0, W); @@ -1417,14 +1422,14 @@ LogicalResult OnnxGruExpander(OpBinder binder, SmallVector slices; for (int64_t i = 0; i < numSlices; ++i) { Value start = - b.create(intType, b.getI64IntegerAttr(i * sliceSize)); - Value end = b.create( - intType, b.getI64IntegerAttr((i + 1) * sliceSize)); - - 
Value slice = b.create(sliceType, tensor, - cstZero, // dim to slice on - start, end, - cstOne // step + ConstantIntOp::create(b, intType, b.getI64IntegerAttr(i * sliceSize)); + Value end = ConstantIntOp::create( + b, intType, b.getI64IntegerAttr((i + 1) * sliceSize)); + + Value slice = AtenSliceTensorOp::create(b, sliceType, tensor, + cstZero, // dim to slice on + start, end, + cstOne // step ); slices.push_back(slice); @@ -1470,13 +1475,13 @@ LogicalResult OnnxGruExpander(OpBinder binder, Y_final = cstNone; } else { if (layout == 0) { - Y_final = b.create(yTy, gruLayerOutput.Y, cstOne); + Y_final = AtenUnsqueezeOp::create(b, yTy, gruLayerOutput.Y, cstOne); } else { Type yTy_original = b.getType( llvm::SmallVector{seq_len, 1, batch_size, hidden_size}, yTy.getDtype()); Y_final = - b.create(yTy_original, gruLayerOutput.Y, cstOne); + AtenUnsqueezeOp::create(b, yTy_original, gruLayerOutput.Y, cstOne); Y_final = StaticTranspose(b, Y_final, 1, 2); Y_final = StaticTranspose(b, Y_final, 0, 1); } @@ -1488,13 +1493,13 @@ LogicalResult OnnxGruExpander(OpBinder binder, } else { if (layout == 0) { Y_h_final = - b.create(Y_hType, gruLayerOutput.Y_h, cstZero); + AtenUnsqueezeOp::create(b, Y_hType, gruLayerOutput.Y_h, cstZero); } else { Type y_hTy_original = b.getType( llvm::SmallVector{1, batch_size, hidden_size}, Y_hType.getDtype()); - Y_h_final = b.create(y_hTy_original, gruLayerOutput.Y_h, - cstZero); + Y_h_final = AtenUnsqueezeOp::create(b, y_hTy_original, gruLayerOutput.Y_h, + cstZero); Y_h_final = StaticTranspose(b, Y_h_final, 0, 1); } } diff --git a/lib/Conversion/TorchOnnxToTorch/Utils.cpp b/lib/Conversion/TorchOnnxToTorch/Utils.cpp index 848ce461e9c1..e14932edb63e 100644 --- a/lib/Conversion/TorchOnnxToTorch/Utils.cpp +++ b/lib/Conversion/TorchOnnxToTorch/Utils.cpp @@ -20,11 +20,11 @@ Value mlir::torch::onnx_c::createConstantIntList( ArrayRef cstInput) { SmallVector cstValue; for (int64_t i : cstInput) { - cstValue.push_back(rewriter.create( - binder.getLoc(), 
rewriter.getI64IntegerAttr(i))); + cstValue.push_back(Torch::ConstantIntOp::create( + rewriter, binder.getLoc(), rewriter.getI64IntegerAttr(i))); } - return rewriter.create( - binder.getLoc(), + return Torch::PrimListConstructOp::create( + rewriter, binder.getLoc(), Torch::ListType::get(Torch::IntType::get(binder.op->getContext())), cstValue); } @@ -109,12 +109,12 @@ LogicalResult mlir::torch::onnx_c::createTorchTransposeOp( if (failed(getTransposedType(cast(input.getType()), dimA, dimB, transposedType))) return failure(); - Value cstDimA = rewriter.create( - loc, rewriter.getI64IntegerAttr(dimA)); - Value cstDimB = rewriter.create( - loc, rewriter.getI64IntegerAttr(dimB)); - transposed = rewriter.create( - loc, transposedType, input, cstDimA, cstDimB); + Value cstDimA = Torch::ConstantIntOp::create( + rewriter, loc, rewriter.getI64IntegerAttr(dimA)); + Value cstDimB = Torch::ConstantIntOp::create( + rewriter, loc, rewriter.getI64IntegerAttr(dimB)); + transposed = Torch::AtenTransposeIntOp::create(rewriter, loc, transposedType, + input, cstDimA, cstDimB); return success(); } @@ -127,19 +127,19 @@ LogicalResult mlir::torch::onnx_c::createTorchPermuteOp( permuteDims, permutedType))) return failure(); Value permuteDimsList = createConstantIntList(binder, rewriter, permuteDims); - permuted = rewriter.create(loc, permutedType, input, - permuteDimsList); + permuted = Torch::AtenPermuteOp::create(rewriter, loc, permutedType, input, + permuteDimsList); return success(); } Value mlir::torch::onnx_c::createActivationByName(ImplicitLocOpBuilder &b, StringRef name, Value input) { if (name == "Sigmoid") - return b.create(input.getType(), input); + return Torch::AtenSigmoidOp::create(b, input.getType(), input); if (name == "Tanh") - return b.create(input.getType(), input); + return Torch::AtenTanhOp::create(b, input.getType(), input); if (name == "Relu") - return b.create(input.getType(), input); + return Torch::AtenReluOp::create(b, input.getType(), input); 
llvm_unreachable("Unsupported activation function"); } @@ -158,8 +158,8 @@ LogicalResult mlir::torch::onnx_c::extractPerTensorQuantizationArguments( if (!check(inScale) || !check(inZeroPoint)) return failure(); - Value emptyList = rewriter.create( - loc, + Value emptyList = Torch::PrimListConstructOp::create( + rewriter, loc, rewriter.getType(rewriter.getType()), ValueRange{}); auto extract = [&rewriter, &loc, &emptyList](Value v) { @@ -167,14 +167,14 @@ LogicalResult mlir::torch::onnx_c::extractPerTensorQuantizationArguments( if (!vTy.getSizes().empty()) { vTy = rewriter.getType(ArrayRef({}), vTy.getOptionalDtype()); - v = rewriter.create(loc, vTy, v, emptyList); + v = Torch::AtenReshapeOp::create(rewriter, loc, vTy, v, emptyList); } Type extractTy = rewriter.getType(); if (isa(vTy.getDtype())) extractTy = rewriter.getType(); - return rewriter.create(loc, extractTy, v); + return Torch::AtenItemOp::create(rewriter, loc, extractTy, v); }; outScale = extract(inScale); @@ -191,14 +191,13 @@ LogicalResult mlir::torch::onnx_c::createDequantizeTensor( return failure(); Torch::ValueTensorType makeTensorTy = getQTorchTypeFromTorchIntType(inputTy); - Value quantizedInput = - rewriter.create( - loc, makeTensorTy, input, scale, zeroPoint); + Value quantizedInput = Torch::Aten_MakePerTensorQuantizedTensorOp::create( + rewriter, loc, makeTensorTy, input, scale, zeroPoint); Torch::ValueTensorType resultTy = rewriter.getType( inputTy.getSizes(), rewriter.getF32Type()); - output = rewriter.create(loc, resultTy, - quantizedInput); + output = Torch::AtenDequantizeSelfOp::create(rewriter, loc, resultTy, + quantizedInput); return success(); } diff --git a/lib/Conversion/TorchToArith/TorchToArith.cpp b/lib/Conversion/TorchToArith/TorchToArith.cpp index 69d585c69ba4..17614f95ea16 100644 --- a/lib/Conversion/TorchToArith/TorchToArith.cpp +++ b/lib/Conversion/TorchToArith/TorchToArith.cpp @@ -41,7 +41,7 @@ class ConvertAtenDimOp : public OpConversionPattern { matchAndRewrite(AtenDimOp op, 
OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto rank = - rewriter.create(op->getLoc(), adaptor.getSelf()); + tensor::RankOp::create(rewriter, op->getLoc(), adaptor.getSelf()); rewriter.replaceOpWithNewOp( op, getTypeConverter()->convertType(op.getType()), rank); return success(); @@ -96,8 +96,8 @@ class ConvertAtenNegIntOp : public OpConversionPattern { Value a = adaptor.getA(); rewriter.replaceOpWithNewOp( op, - rewriter.create(op.getLoc(), /*value=*/0, - /*bitwidth=*/64), + arith::ConstantIntOp::create(rewriter, op.getLoc(), /*value=*/0, + /*bitwidth=*/64), a); return success(); } @@ -119,7 +119,7 @@ class ConvertAtenUnaryOpToFloatMathOp : public OpConversionPattern { this->getTypeConverter()->convertType(op->getResult(0).getType()); if (!isa(input.getType())) input = convertScalarToDtype(rewriter, loc, input, rewriter.getF64Type()); - Value result = rewriter.create(loc, input); + Value result = UnaryOp::create(rewriter, loc, input); rewriter.replaceOp(op, convertScalarToDtype(rewriter, loc, result, resultType)); return success(); @@ -347,7 +347,7 @@ class ConvertAtenAnyOrAllBoolOp : public OpConversionPattern { rewriter, loc, this->getTypeConverter(), inputListTorchBool); result = inputList[0]; for (unsigned i = 1; i < inputList.size(); i++) - result = rewriter.create(loc, result, inputList[i]); + result = BinOp::create(rewriter, loc, result, inputList[i]); rewriter.replaceOp(op, result); return success(); } @@ -385,15 +385,15 @@ class ConvertAtenBoolLikeOp : public OpConversionPattern { ConversionPatternRewriter &rewriter) const override { Location loc = op.getLoc(); Type inputType = adaptor.getA().getType(); - Value cstZero = rewriter.create( - loc, rewriter.getZeroAttr(inputType)); + Value cstZero = arith::ConstantOp::create(rewriter, loc, + rewriter.getZeroAttr(inputType)); Value cstTrue = - rewriter.create(loc, rewriter.getBoolAttr(true)); + arith::ConstantOp::create(rewriter, loc, rewriter.getBoolAttr(true)); Value cstFalse = 
- rewriter.create(loc, rewriter.getBoolAttr(false)); + arith::ConstantOp::create(rewriter, loc, rewriter.getBoolAttr(false)); Value cmpPred; - cmpPred = rewriter.create(loc, Pred, adaptor.getA(), cstZero); + cmpPred = CmpOpTy::create(rewriter, loc, Pred, adaptor.getA(), cstZero); rewriter.replaceOpWithNewOp(op, cmpPred, cstTrue, cstFalse); return success(); diff --git a/lib/Conversion/TorchToLinalg/DataMovement.cpp b/lib/Conversion/TorchToLinalg/DataMovement.cpp index eb08786f7982..1cbb4aa99288 100644 --- a/lib/Conversion/TorchToLinalg/DataMovement.cpp +++ b/lib/Conversion/TorchToLinalg/DataMovement.cpp @@ -48,9 +48,9 @@ LogicalResult prepareArgumentsForSlicingOp(OpTy op, OpAdaptor adaptor, auto input = adaptor.getSelf(); RankedTensorType inputType = cast(input.getType()); - Value zero = rewriter.create(loc, 0); - Value one = rewriter.create(loc, 1); - Value negone = rewriter.create(loc, -1); + Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0); + Value one = arith::ConstantIndexOp::create(rewriter, loc, 1); + Value negone = arith::ConstantIndexOp::create(rewriter, loc, -1); if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim))) return op->emitError("unimplemented: dim is not constant"); @@ -83,34 +83,35 @@ LogicalResult prepareArgumentsForSlicingOp(OpTy op, OpAdaptor adaptor, end = dimSize; } else { end = castIntToIndex(rewriter, loc, end); - Value endcmp = rewriter.create( - loc, arith::CmpIPredicate::slt, end, zero); - Value endadd = rewriter.create(loc, end, dimSize); - end = rewriter.create(loc, endcmp, endadd, end); - endcmp = rewriter.create(loc, arith::CmpIPredicate::slt, end, - zero); - end = rewriter.create(loc, endcmp, negone, end); - endcmp = rewriter.create(loc, arith::CmpIPredicate::sgt, end, - dimSize); - end = rewriter.create(loc, endcmp, dimSize, end); + Value endcmp = arith::CmpIOp::create(rewriter, loc, + arith::CmpIPredicate::slt, end, zero); + Value endadd = arith::AddIOp::create(rewriter, loc, end, dimSize); + end = 
arith::SelectOp::create(rewriter, loc, endcmp, endadd, end); + endcmp = arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::slt, + end, zero); + end = arith::SelectOp::create(rewriter, loc, endcmp, negone, end); + endcmp = arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::sgt, + end, dimSize); + end = arith::SelectOp::create(rewriter, loc, endcmp, dimSize, end); } // Slice logic: resultSize = floordiv(end - start + step - 1, step) resultShape = getTensorSizes(rewriter, loc, input); - Value len = rewriter.create(loc, end, start); + Value len = arith::SubIOp::create(rewriter, loc, end, start); // We check the difference between start and end to determine the total size: - Value stepcmp = rewriter.create(loc, arith::CmpIPredicate::sge, - stepIndex, zero); - Value stepsign = rewriter.create(loc, stepcmp, one, negone); - Value resultSize = rewriter.create(loc, len, stepIndex); - resultSize = rewriter.create(loc, resultSize, stepsign); - resultSize = rewriter.create(loc, resultSize, stepIndex); + Value stepcmp = arith::CmpIOp::create( + rewriter, loc, arith::CmpIPredicate::sge, stepIndex, zero); + Value stepsign = arith::SelectOp::create(rewriter, loc, stepcmp, one, negone); + Value resultSize = arith::AddIOp::create(rewriter, loc, len, stepIndex); + resultSize = arith::SubIOp::create(rewriter, loc, resultSize, stepsign); + resultSize = + arith::FloorDivSIOp::create(rewriter, loc, resultSize, stepIndex); // Clamp the size to [0, ...]: - Value szcmp = rewriter.create(loc, arith::CmpIPredicate::slt, - resultSize, zero); - resultSize = rewriter.create(loc, szcmp, zero, resultSize); + Value szcmp = arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::slt, + resultSize, zero); + resultSize = arith::SelectOp::create(rewriter, loc, szcmp, zero, resultSize); resultShape[dim] = resultSize; strides.resize(inputType.getRank(), one); @@ -174,12 +175,12 @@ class ConvertAtenReflectionPad1dOp // Lambda Unitility Functions // Create an Integer expression of x + y 
auto createIAdd = [&](Value x, Value y) { - return rewriter.create(loc, x, y); + return arith::AddIOp::create(rewriter, loc, x, y); }; // Create an integer expression of x - y auto createISub = [&](Value x, Value y) { - return rewriter.create(loc, x, y); + return arith::SubIOp::create(rewriter, loc, x, y); }; enum PadLocation { PAD_LEFT = 0, PAD_RIGHT = 1, PAD_CENTER = 2 }; @@ -248,8 +249,8 @@ class ConvertAtenReflectionPad1dOp extractShape[lastDim] = tileWidth[padPosition]; SmallVector extractOffsets(numDims, zero); extractOffsets[lastDim] = extractOffset[padPosition]; - Value tile = rewriter.create( - loc, input, extractOffsets, extractShape, allOneStrides); + Value tile = tensor::ExtractSliceOp::create( + rewriter, loc, input, extractOffsets, extractShape, allOneStrides); auto inputMap = AffineMap::getMultiDimIdentityMap(numDims, context); // Setup the affine map function to resverse the tile along the horizontal @@ -257,20 +258,20 @@ class ConvertAtenReflectionPad1dOp if (padPosition < PAD_CENTER) { inputMap = reflectDim(inputMap, numDims, lastDim, padInts[padPosition]); // Take reflected slice as per inputMap - tile = rewriter - .create( - loc, llvm::cast(tile.getType()), tile, - tile, ArrayRef({inputMap, idMap}), iteratorTypes, - [](OpBuilder &b, Location nestedLoc, ValueRange args) { - b.create(nestedLoc, args[0]); - }) + tile = linalg::GenericOp::create( + rewriter, loc, llvm::cast(tile.getType()), + tile, tile, ArrayRef({inputMap, idMap}), iteratorTypes, + [](OpBuilder &b, Location nestedLoc, ValueRange args) { + linalg::YieldOp::create(b, nestedLoc, args[0]); + }) .getResult(0); } // Insert the tile in the resultTensor SmallVector insertOffsets(numDims, zero); insertOffsets[lastDim] = insertOffset[padPosition]; - resultTensor = rewriter.create( - loc, tile, resultTensor, insertOffsets, extractShape, allOneStrides); + resultTensor = tensor::InsertSliceOp::create(rewriter, loc, tile, + resultTensor, insertOffsets, + extractShape, allOneStrides); }; if 
(padInts[PAD_LEFT] > 0) @@ -343,7 +344,7 @@ class ConvertAtenReflectionPad2dOp Location loc = op.getLoc(); // Some generic helper functions for creating arithmetic operations. auto createAdd = [&](Value x, Value y) { - return rewriter.create(loc, x, y); + return arith::AddIOp::create(rewriter, loc, x, y); }; auto createAdds = [&](std::initializer_list values) { @@ -353,7 +354,7 @@ class ConvertAtenReflectionPad2dOp }; auto createSub = [&](Value x, Value y) { - return rewriter.create(loc, x, y); + return arith::SubIOp::create(rewriter, loc, x, y); }; auto createSubs = [&](std::initializer_list values) { @@ -406,13 +407,14 @@ class ConvertAtenReflectionPad2dOp auto verifyPadding = [&](int64_t padArgument, int64_t dim, StringRef errorMessage) { - auto padValue = rewriter.create(loc, padArgument); - Value index = rewriter.create(loc, dim); - Value shapeDim = rewriter.create(loc, input, index); - Value cmpPred = rewriter.create( - loc, arith::CmpIPredicate::sle, padValue, shapeDim); - rewriter.create(loc, cmpPred, - rewriter.getStringAttr(errorMessage)); + auto padValue = + arith::ConstantIndexOp::create(rewriter, loc, padArgument); + Value index = arith::ConstantIndexOp::create(rewriter, loc, dim); + Value shapeDim = tensor::DimOp::create(rewriter, loc, input, index); + Value cmpPred = arith::CmpIOp::create( + rewriter, loc, arith::CmpIPredicate::sle, padValue, shapeDim); + cf::AssertOp::create(rewriter, loc, cmpPred, + rewriter.getStringAttr(errorMessage)); }; verifyPadding(getHPadArgument(LEFT), hDim, "Left padding too large"); @@ -524,43 +526,41 @@ class ConvertAtenReflectionPad2dOp extractOffsets[hDim] = extractHOffset[horizontalPos]; extractOffsets[vDim] = extractVOffset[verticalPos]; - Value tile = rewriter.create( - loc, input, extractOffsets, extractShape, allOneStrides); + Value tile = tensor::ExtractSliceOp::create( + rewriter, loc, input, extractOffsets, extractShape, allOneStrides); auto inputMap = AffineMap::getMultiDimIdentityMap(numDims, context); tile = 
- rewriter - .create( - loc, llvm::cast(tile.getType()), tile, tile, - ArrayRef({inputMap, idMap}), iteratorTypes, - [&](OpBuilder &b, Location nestedLoc, ValueRange args) { - // Use linalg.index to reflect the dims - SmallVector extractIndices(numDims); - for (unsigned i = 0; i < numDims; i++) - extractIndices[i] = - b.create(nestedLoc, i); - - auto reflectDim = [&](int64_t padSize, Value dim) { - Value reflectDimSize = getConstant( - rewriter, loc, padSize - 1, rewriter.getIndexType()); - return b.create(loc, reflectDimSize, dim); - }; - - // Reverse the tile along the horizontal, vertical, or both - // dimensions. - if (shouldHReflect(horizontalPos)) - extractIndices[hDim] = reflectDim( - getHPadArgument(horizontalPos), extractIndices[hDim]); - - if (shouldVReflect(verticalPos)) - extractIndices[vDim] = reflectDim( - getVPadArgument(verticalPos), extractIndices[vDim]); - - Value extractValue = rewriter.create( - nestedLoc, tile, extractIndices); - b.create(nestedLoc, extractValue); - }) + linalg::GenericOp::create( + rewriter, loc, llvm::cast(tile.getType()), tile, + tile, ArrayRef({inputMap, idMap}), iteratorTypes, + [&](OpBuilder &b, Location nestedLoc, ValueRange args) { + // Use linalg.index to reflect the dims + SmallVector extractIndices(numDims); + for (unsigned i = 0; i < numDims; i++) + extractIndices[i] = linalg::IndexOp::create(b, nestedLoc, i); + + auto reflectDim = [&](int64_t padSize, Value dim) { + Value reflectDimSize = getConstant(rewriter, loc, padSize - 1, + rewriter.getIndexType()); + return arith::SubIOp::create(b, loc, reflectDimSize, dim); + }; + + // Reverse the tile along the horizontal, vertical, or both + // dimensions. 
+ if (shouldHReflect(horizontalPos)) + extractIndices[hDim] = reflectDim( + getHPadArgument(horizontalPos), extractIndices[hDim]); + + if (shouldVReflect(verticalPos)) + extractIndices[vDim] = reflectDim( + getVPadArgument(verticalPos), extractIndices[vDim]); + + Value extractValue = tensor::ExtractOp::create( + rewriter, nestedLoc, tile, extractIndices); + linalg::YieldOp::create(b, nestedLoc, extractValue); + }) .getResult(0); // Insert the tile in the resultTensor. @@ -568,8 +568,9 @@ class ConvertAtenReflectionPad2dOp insertOffsets[hDim] = insertHOffset[horizontalPos]; insertOffsets[vDim] = insertVOffset[verticalPos]; - resultTensor = rewriter.create( - loc, tile, resultTensor, insertOffsets, extractShape, allOneStrides); + resultTensor = tensor::InsertSliceOp::create(rewriter, loc, tile, + resultTensor, insertOffsets, + extractShape, allOneStrides); }; for (auto v : {TOP, BOTTOM, VCENTER}) @@ -638,8 +639,8 @@ class ConvertAtenFlattenUsingIntsOp if (i < startDim || i >= endDim) j++; } - Value collapsedTensor = rewriter.create( - op->getLoc(), adaptor.getSelf(), reassociation); + Value collapsedTensor = tensor::CollapseShapeOp::create( + rewriter, op->getLoc(), adaptor.getSelf(), reassociation); rewriter.replaceOpWithNewOp(op, resultType, collapsedTensor); return success(); @@ -717,9 +718,8 @@ class ConvertAtenUnflattenIntOp for (int i = dimInt + numSizes; i < outputRank; ++i) reassociations[i - numSizes + 1].push_back(i); } - expand = rewriter - .create( - loc, expandTy, adaptor.getSelf(), reassociations) + expand = tensor::ExpandShapeOp::create(rewriter, loc, expandTy, + adaptor.getSelf(), reassociations) .getResult(); } else { reassocSizes = getTypeConvertedValues(rewriter, loc, getTypeConverter(), @@ -739,10 +739,9 @@ class ConvertAtenUnflattenIntOp RankedTensorType shapeType = RankedTensorType::get( ArrayRef{outputRank}, rewriter.getIntegerType(64)); Value shapeValue = - rewriter.create(loc, shapeType, outputShape); - expand = rewriter - .create(loc, 
expandTy, adaptor.getSelf(), - shapeValue) + tensor::FromElementsOp::create(rewriter, loc, shapeType, outputShape); + expand = tensor::ReshapeOp::create(rewriter, loc, expandTy, + adaptor.getSelf(), shapeValue) .getResult(); } rewriter.replaceOp(op, expand); @@ -1014,8 +1013,9 @@ class ConvertAtenViewOp : public OpConversionPattern { llvm::SmallVector outshape(resultRank, 1); auto expandTy = RankedTensorType::get(outshape, resultType.getElementType()); - Value expand = rewriter.create( - op.getLoc(), expandTy, input, ArrayRef()); + Value expand = + tensor::ExpandShapeOp::create(rewriter, op.getLoc(), expandTy, input, + ArrayRef()); rewriter.replaceOpWithNewOp(op, resultType, expand); return success(); } @@ -1299,9 +1299,8 @@ class ConvertAtenViewOp : public OpConversionPattern { resultType.getElementType()); expandedInput = - rewriter - .create(loc, intermediateResultType, - castedInput, inputAssociations) + tensor::CollapseShapeOp::create(rewriter, loc, intermediateResultType, + castedInput, inputAssociations) .getResult(); } @@ -1309,13 +1308,12 @@ class ConvertAtenViewOp : public OpConversionPattern { return indices.size() > 1; })) { - collapsedInput = rewriter - .create( - loc, adjustedResultType, - expandedInput.has_value() ? expandedInput.value() - : castedInput, - outputAssociations) - .getResult(); + collapsedInput = + tensor::ExpandShapeOp::create( + rewriter, loc, adjustedResultType, + expandedInput.has_value() ? expandedInput.value() : castedInput, + outputAssociations) + .getResult(); } Value result = collapsedInput.has_value() ? collapsedInput.value() @@ -1354,23 +1352,23 @@ class ConvertAtenViewOpToReshape : public OpConversionPattern { // the inferred dimensions sizes. 
auto sizeTy = cast(typeConverter->convertType(sizes.front().getType())); - Value one = - b.create(sizeTy, rewriter.getIntegerAttr(sizeTy, 1)); - Value zero = - b.create(sizeTy, rewriter.getIntegerAttr(sizeTy, 0)); + Value one = arith::ConstantOp::create(b, sizeTy, + rewriter.getIntegerAttr(sizeTy, 1)); + Value zero = arith::ConstantOp::create(b, sizeTy, + rewriter.getIntegerAttr(sizeTy, 0)); Value count = zero; Value knownSize = one; for (auto &size : sizes) { Value convert = typeConverter->materializeTargetConversion(rewriter, loc, sizeTy, size); - Value mul = b.create(knownSize, convert); - Value add = b.create(count, one); + Value mul = arith::MulIOp::create(b, knownSize, convert); + Value add = arith::AddIOp::create(b, count, one); Value isNeg = - b.create(arith::CmpIPredicate::slt, convert, zero); + arith::CmpIOp::create(b, arith::CmpIPredicate::slt, convert, zero); - knownSize = b.create(isNeg, knownSize, mul); - count = b.create(isNeg, add, count); + knownSize = arith::SelectOp::create(b, isNeg, knownSize, mul); + count = arith::SelectOp::create(b, isNeg, add, count); size = convert; } @@ -1378,9 +1376,9 @@ class ConvertAtenViewOpToReshape : public OpConversionPattern { // strict mode, there will only ever statically be one inferred dim. 
if (!isAssumingStrictSymbolicShapes(rewriter)) { Value countPred = - b.create(arith::CmpIPredicate::sle, count, one); - b.create( - loc, countPred, + arith::CmpIOp::create(b, arith::CmpIPredicate::sle, count, one); + cf::AssertOp::create( + b, loc, countPred, b.getStringAttr( "must have at most one inferred (negative) dimension")); } @@ -1390,21 +1388,21 @@ class ConvertAtenViewOpToReshape : public OpConversionPattern { auto selfTy = cast(self.getType()); Value totalSize = one; for (int i = 0, s = selfTy.getRank(); i < s; ++i) { - Value index = b.create(i); - Value dim = b.create(self, index); - dim = b.create(sizeTy, dim); - totalSize = b.create(totalSize, dim); + Value index = arith::ConstantIndexOp::create(b, i); + Value dim = tensor::DimOp::create(b, self, index); + dim = arith::IndexCastOp::create(b, sizeTy, dim); + totalSize = arith::MulIOp::create(b, totalSize, dim); } - Value inferredSize = b.create(totalSize, knownSize); + Value inferredSize = arith::DivSIOp::create(b, totalSize, knownSize); for (auto &size : sizes) { Value isNeg = - b.create(arith::CmpIPredicate::slt, size, zero); - size = b.create(isNeg, inferredSize, size); + arith::CmpIOp::create(b, arith::CmpIPredicate::slt, size, zero); + size = arith::SelectOp::create(b, isNeg, inferredSize, size); } auto ty = RankedTensorType::get(sizes.size(), sizes.front().getType()); - auto outputDims = b.create(ty, sizes); + auto outputDims = tensor::FromElementsOp::create(b, ty, sizes); auto resultType = cast(typeConverter->convertType(op.getType())); @@ -1498,12 +1496,12 @@ class ConvertAtenViewOpStrict : public OpConversionPattern { // Flatten to 1D. 
ValueTensorType flatType = rewriter.getType( ArrayRef{flatDim}, selfTy.getOptionalDtype()); - Value dimStart = rewriter.create( - loc, rewriter.getI64IntegerAttr(0)); - Value dimEnd = rewriter.create( - loc, rewriter.getI64IntegerAttr(selfSizes.size() - 1)); - Value flatSelf = rewriter.create( - loc, flatType, op.getSelf(), dimStart, dimEnd); + Value dimStart = Torch::ConstantIntOp::create( + rewriter, loc, rewriter.getI64IntegerAttr(0)); + Value dimEnd = Torch::ConstantIntOp::create( + rewriter, loc, rewriter.getI64IntegerAttr(selfSizes.size() - 1)); + Value flatSelf = Torch::AtenFlattenUsingIntsOp::create( + rewriter, loc, flatType, op.getSelf(), dimStart, dimEnd); // Unflatten to requested size. rewriter.replaceOpWithNewOp( @@ -1520,8 +1518,8 @@ class ConvertAtenViewOpStrict : public OpConversionPattern { if (inferredDim >= 0) { // Inferred dim. If the above flatten/unflatten logic ever catches // everything, this branch can go away entirely. - Value one = rewriter.create( - loc, sizeTy, rewriter.getIntegerAttr(sizeTy, 1)); + Value one = arith::ConstantOp::create(rewriter, loc, sizeTy, + rewriter.getIntegerAttr(sizeTy, 1)); Value sizeProduct = one; // Multiply the non-inferred target sizes. for (int i = 0, e = sizeValues.size(); i < e; ++i) { @@ -1532,20 +1530,20 @@ class ConvertAtenViewOpStrict : public OpConversionPattern { rewriter, loc, sizeTy, size); assert(convertedSize && "Type converter did not handle size"); sizeProduct = - rewriter.create(loc, sizeProduct, convertedSize); + arith::MulIOp::create(rewriter, loc, sizeProduct, convertedSize); } // Multiply the self tensor sizes. 
Value selfProduct = one; for (int i = 0, e = selfTy.getRank(); i < e; ++i) { - Value index = rewriter.create(loc, i); - Value dim = rewriter.create(loc, self, index); - dim = rewriter.create(loc, sizeTy, dim); - selfProduct = rewriter.create(loc, selfProduct, dim); + Value index = arith::ConstantIndexOp::create(rewriter, loc, i); + Value dim = tensor::DimOp::create(rewriter, loc, self, index); + dim = arith::IndexCastOp::create(rewriter, loc, sizeTy, dim); + selfProduct = arith::MulIOp::create(rewriter, loc, selfProduct, dim); } Value inferredSize = - rewriter.create(loc, selfProduct, sizeProduct); + arith::DivUIOp::create(rewriter, loc, selfProduct, sizeProduct); for (int i = 0, e = sizeValues.size(); i < e; ++i) { if (i == inferredDim) { outputDimValues.push_back(inferredSize); @@ -1565,8 +1563,8 @@ class ConvertAtenViewOpStrict : public OpConversionPattern { // Normal lowering to reshape with fully computed sizes. auto outputDimsTy = RankedTensorType::get( outputDimValues.size(), outputDimValues.front().getType()); - auto outputDims = rewriter.create(loc, outputDimsTy, - outputDimValues); + auto outputDims = tensor::FromElementsOp::create( + rewriter, loc, outputDimsTy, outputDimValues); rewriter.replaceOpWithNewOp( op, resultType, adaptor.getSelf(), outputDims); return success(); @@ -1739,18 +1737,17 @@ class ConvertAtenTransposeIntOp outputDims.push_back(getDimOp(rewriter, loc, adaptor.getSelf(), i)); std::swap(outputDims[dim0], outputDims[dim1]); - Value outVector = rewriter.create( - loc, getAsOpFoldResult(outputDims), elementType); + Value outVector = tensor::EmptyOp::create( + rewriter, loc, getAsOpFoldResult(outputDims), elementType); SmallVector permutation(inputRank); std::iota(permutation.begin(), permutation.end(), 0); permutation[dim0] = dim1; permutation[dim1] = dim0; - auto transpose = - rewriter - .create(loc, inVector, outVector, permutation) - .getResult(); + auto transpose = linalg::TransposeOp::create(rewriter, loc, inVector, + outVector, 
permutation) + .getResult(); rewriter.replaceOpWithNewOp(op, outType, transpose); return success(); } @@ -1827,27 +1824,27 @@ class ConvertAtenSliceTensorOp : public OpConversionPattern { Value flippedInput = torch_to_linalg::flipTensor(rewriter, loc, input, SmallVector{dim}); - Value cstDim = rewriter.create(loc, dim); - Value zero = rewriter.create(loc, 0); - Value isNegativeStride = rewriter.create( - loc, arith::CmpIPredicate::slt, strides[dim], zero); - strides[dim] = rewriter.create(loc, strides[dim]); + Value cstDim = arith::ConstantIndexOp::create(rewriter, loc, dim); + Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0); + Value isNegativeStride = arith::CmpIOp::create( + rewriter, loc, arith::CmpIPredicate::slt, strides[dim], zero); + strides[dim] = math::AbsIOp::create(rewriter, loc, strides[dim]); Value resShapeMulStride = - rewriter.create(loc, resultShape[dim], strides[dim]); - Value inputDim = rewriter.create(loc, input, cstDim); + arith::MulIOp::create(rewriter, loc, resultShape[dim], strides[dim]); + Value inputDim = tensor::DimOp::create(rewriter, loc, input, cstDim); Value flippedOffset = - rewriter.create(loc, inputDim, resShapeMulStride); - offsets[dim] = rewriter.create( - loc, isNegativeStride, flippedOffset, offsets[dim]); + arith::SubIOp::create(rewriter, loc, inputDim, resShapeMulStride); + offsets[dim] = arith::SelectOp::create(rewriter, loc, isNegativeStride, + flippedOffset, offsets[dim]); - input = rewriter.create(loc, isNegativeStride, - flippedInput, input); + input = arith::SelectOp::create(rewriter, loc, isNegativeStride, + flippedInput, input); SmallVector dynShape(resultType.getRank(), ShapedType::kDynamic); auto sliceType = RankedTensorType::get( dynShape, resultType.getElementType(), resultType.getEncoding()); - Value result = rewriter.create( - loc, sliceType, input, offsets, resultShape, strides); + Value result = tensor::ExtractSliceOp::create( + rewriter, loc, sliceType, input, offsets, resultShape, strides); 
rewriter.replaceOpWithNewOp(op, resultType, result); return success(); @@ -2028,22 +2025,21 @@ class ConvertAtenCopyOp : public OpConversionPattern { rewriter.getContext()); SmallVector iteratorTypes( selfType.getRank(), utils::IteratorType::parallel); - Value result = rewriter - .create( - loc, - /*resultType=*/selfType, - /*inputs=*/broadcastedSrc, - /*outputs=*/self, - /*indexingMaps=*/llvm::ArrayRef({id, id}), - /*iteratorTypes=*/iteratorTypes, - [](OpBuilder &b, Location loc, ValueRange args) { - Value result = args[0]; - if (args[0].getType() != args[1].getType()) { - result = convertScalarToDtype(b, loc, args[0], - args[1].getType()); - } - b.create(loc, result); - }) + Value result = linalg::GenericOp::create( + rewriter, loc, + /*resultType=*/selfType, + /*inputs=*/broadcastedSrc, + /*outputs=*/self, + /*indexingMaps=*/llvm::ArrayRef({id, id}), + /*iteratorTypes=*/iteratorTypes, + [](OpBuilder &b, Location loc, ValueRange args) { + Value result = args[0]; + if (args[0].getType() != args[1].getType()) { + result = convertScalarToDtype(b, loc, args[0], + args[1].getType()); + } + linalg::YieldOp::create(b, loc, result); + }) ->getResult(0); Type resultType = getTypeConverter()->convertType(op.getType()); @@ -2089,10 +2085,10 @@ class ConvertAtenSliceScatterOp auto abstractSrcType = RankedTensorType::get( makeShapeLLVMCompatible(srcAbstractSizes), srcType.getElementType()); Value abstractSrc = - rewriter.create(loc, abstractSrcType, src); + tensor::CastOp::create(rewriter, loc, abstractSrcType, src); - Value result = rewriter.create( - loc, abstractSrc, input, offsets, resultShape, strides); + Value result = tensor::InsertSliceOp::create( + rewriter, loc, abstractSrc, input, offsets, resultShape, strides); rewriter.replaceOpWithNewOp(op, resultType, result); @@ -2125,12 +2121,12 @@ class ConvertAtenViewAsComplexOp auto elementType = resultType.getElementType(); SmallVector resultShape; for (int64_t i = 0; i < resultType.getRank(); i++) { - auto currentDimSize 
= rewriter.create(loc, input, i); + auto currentDimSize = tensor::DimOp::create(rewriter, loc, input, i); resultShape.push_back(currentDimSize); } - Value outTensor = rewriter.create( - loc, getAsOpFoldResult(resultShape), elementType); + Value outTensor = tensor::EmptyOp::create( + rewriter, loc, getAsOpFoldResult(resultShape), elementType); SmallVector outputExpr; for (unsigned i = 0; i < resultType.getRank(); i++) { @@ -2149,30 +2145,29 @@ class ConvertAtenViewAsComplexOp SmallVector iteratorTypes( resultType.getRank(), utils::IteratorType::parallel); auto complexVar = - rewriter - .create( - loc, outTensor.getType(), ValueRange{}, outTensor, indexingMaps, - iteratorTypes, - [&](OpBuilder &b, Location loc, ValueRange args) { - SmallVector indicesZero; - SmallVector indicesOne; - - for (int i = 0; i < resultType.getRank(); i++) { - indicesZero.push_back(b.create(loc, i)); - indicesOne.push_back(b.create(loc, i)); - } - - indicesZero.push_back(constantZero); - indicesOne.push_back(constantOne); - - Value realVal = - b.create(loc, input, indicesZero); - Value imagVal = - b.create(loc, input, indicesOne); - Value complexVal = b.create( - loc, elementType, realVal, imagVal); - b.create(loc, complexVal); - }) + linalg::GenericOp::create( + rewriter, loc, outTensor.getType(), ValueRange{}, outTensor, + indexingMaps, iteratorTypes, + [&](OpBuilder &b, Location loc, ValueRange args) { + SmallVector indicesZero; + SmallVector indicesOne; + + for (int i = 0; i < resultType.getRank(); i++) { + indicesZero.push_back(linalg::IndexOp::create(b, loc, i)); + indicesOne.push_back(linalg::IndexOp::create(b, loc, i)); + } + + indicesZero.push_back(constantZero); + indicesOne.push_back(constantOne); + + Value realVal = + tensor::ExtractOp::create(b, loc, input, indicesZero); + Value imagVal = + tensor::ExtractOp::create(b, loc, input, indicesOne); + Value complexVal = complex::CreateOp::create(b, loc, elementType, + realVal, imagVal); + linalg::YieldOp::create(b, loc, complexVal); + 
}) .getResult(0); rewriter.replaceOpWithNewOp(op, resultType, complexVar); return success(); @@ -2214,7 +2209,7 @@ class ConvertAtenViewAsRealOp : public OpConversionPattern { rewriter.createOrFold(loc, 2)); Value outTensor = - rewriter.create(loc, resultShape, elementType); + tensor::EmptyOp::create(rewriter, loc, resultShape, elementType); SmallVector inputExpr; for (unsigned i = 0; i < resultType.getRank() - 1; i++) { @@ -2237,24 +2232,23 @@ class ConvertAtenViewAsRealOp : public OpConversionPattern { Value constantZero = getConstant(rewriter, loc, 0, mlir::IndexType::get(context)); auto realVar = - rewriter - .create( - loc, outTensor.getType(), input, outTensor, indexingMaps, - iteratorTypes, - [&](OpBuilder &b, Location loc, ValueRange args) { - Value realVal = - b.create(loc, elementType, args[0]); - Value imagVal = - b.create(loc, elementType, args[0]); - Value lastIndex = - b.create(loc, inputType.getRank()); - Value cmpResult = b.create( - loc, arith::CmpIPredicate::eq, lastIndex, constantZero); - Value yieldValue = b.create( - loc, cmpResult, realVal, imagVal); - - b.create(loc, yieldValue); - }) + linalg::GenericOp::create( + rewriter, loc, outTensor.getType(), input, outTensor, indexingMaps, + iteratorTypes, + [&](OpBuilder &b, Location loc, ValueRange args) { + Value realVal = + complex::ReOp::create(b, loc, elementType, args[0]); + Value imagVal = + complex::ImOp::create(b, loc, elementType, args[0]); + Value lastIndex = + linalg::IndexOp::create(b, loc, inputType.getRank()); + Value cmpResult = arith::CmpIOp::create( + b, loc, arith::CmpIPredicate::eq, lastIndex, constantZero); + Value yieldValue = + arith::SelectOp::create(b, loc, cmpResult, realVal, imagVal); + + linalg::YieldOp::create(b, loc, yieldValue); + }) .getResult(0); rewriter.replaceOpWithNewOp(op, resultType, realVar); @@ -2315,34 +2309,37 @@ class ConvertAtenDiagonalOp : public OpConversionPattern { // compute the length of the diagonal with possible offset // if the offset is very 
large or very small, diagSize=0 and an empty tensor // is returned - Value indexZero = rewriter.create(loc, 0); - Value indexMinusOne = rewriter.create(loc, -1); - Value indexOffset = rewriter.create(loc, offset); - Value offsetIsNegative = rewriter.create( - loc, arith::CmpIPredicate::sle, indexOffset, indexZero); - Value sizeForNegativeOffset = rewriter.create( - loc, - rewriter.create( - loc, rewriter.create(loc, dim1Size, indexOffset), + Value indexZero = arith::ConstantIndexOp::create(rewriter, loc, 0); + Value indexMinusOne = arith::ConstantIndexOp::create(rewriter, loc, -1); + Value indexOffset = arith::ConstantIndexOp::create(rewriter, loc, offset); + Value offsetIsNegative = arith::CmpIOp::create( + rewriter, loc, arith::CmpIPredicate::sle, indexOffset, indexZero); + Value sizeForNegativeOffset = arith::MaxSIOp::create( + rewriter, loc, + arith::MinSIOp::create( + rewriter, loc, + arith::AddIOp::create(rewriter, loc, dim1Size, indexOffset), dim2Size), indexZero); - Value sizeForPositiveOffset = rewriter.create( - loc, - rewriter.create( - loc, rewriter.create(loc, dim2Size, indexOffset), + Value sizeForPositiveOffset = arith::MaxSIOp::create( + rewriter, loc, + arith::MinSIOp::create( + rewriter, loc, + arith::SubIOp::create(rewriter, loc, dim2Size, indexOffset), dim1Size), indexZero); - Value diagSize = rewriter.create( - loc, offsetIsNegative, sizeForNegativeOffset, sizeForPositiveOffset); + Value diagSize = + arith::SelectOp::create(rewriter, loc, offsetIsNegative, + sizeForNegativeOffset, sizeForPositiveOffset); // depending on its sign, the offset affects only the row or column indices // of the diagonal - Value diagStart1 = rewriter.create( - loc, offsetIsNegative, - rewriter.create(loc, indexOffset, indexMinusOne), + Value diagStart1 = arith::SelectOp::create( + rewriter, loc, offsetIsNegative, + arith::MulIOp::create(rewriter, loc, indexOffset, indexMinusOne), indexZero); - Value diagStart2 = rewriter.create(loc, offsetIsNegative, - indexZero, 
indexOffset); + Value diagStart2 = arith::SelectOp::create(rewriter, loc, offsetIsNegative, + indexZero, indexOffset); SmallVector outputDims; for (auto i = 0; i < inputRank; i++) { @@ -2351,8 +2348,8 @@ class ConvertAtenDiagonalOp : public OpConversionPattern { } outputDims.push_back(diagSize); - Value outputMatrix = rewriter.create( - loc, getAsOpFoldResult(outputDims), elementType); + Value outputMatrix = tensor::EmptyOp::create( + rewriter, loc, getAsOpFoldResult(outputDims), elementType); SmallVector indexingMaps = { AffineMap::getMultiDimIdentityMap(outputRank, rewriter.getContext())}; @@ -2360,36 +2357,35 @@ class ConvertAtenDiagonalOp : public OpConversionPattern { outputRank, utils::IteratorType::parallel); auto diagonal = - rewriter - .create( - loc, outputMatrix.getType(), ValueRange{}, outputMatrix, - indexingMaps, iteratorTypes, - [&](OpBuilder &b, Location loc, ValueRange args) { - SmallVector diagIndices; - Value indexOnDiag = - b.create(loc, outputRank - 1); - Value dim1Index = - b.create(loc, indexOnDiag, diagStart1); - Value dim2Index = - b.create(loc, indexOnDiag, diagStart2); - - // specify at which input indices the diagonal values are - // extracted - for (int indIn = 0, indOut = 0; indIn < inputRank; indIn++) { - if (indIn == dim1) - diagIndices.push_back(dim1Index); - else if (indIn == dim2) - diagIndices.push_back(dim2Index); - else { - diagIndices.push_back( - b.create(loc, indOut)); - indOut++; - } - } - Value diagElt = b.create( - loc, elementType, inputMatrix, diagIndices); - b.create(loc, diagElt); - }) + linalg::GenericOp::create( + rewriter, loc, outputMatrix.getType(), ValueRange{}, outputMatrix, + indexingMaps, iteratorTypes, + [&](OpBuilder &b, Location loc, ValueRange args) { + SmallVector diagIndices; + Value indexOnDiag = + linalg::IndexOp::create(b, loc, outputRank - 1); + Value dim1Index = + arith::AddIOp::create(b, loc, indexOnDiag, diagStart1); + Value dim2Index = + arith::AddIOp::create(b, loc, indexOnDiag, diagStart2); + 
+ // specify at which input indices the diagonal values are + // extracted + for (int indIn = 0, indOut = 0; indIn < inputRank; indIn++) { + if (indIn == dim1) + diagIndices.push_back(dim1Index); + else if (indIn == dim2) + diagIndices.push_back(dim2Index); + else { + diagIndices.push_back( + linalg::IndexOp::create(b, loc, indOut)); + indOut++; + } + } + Value diagElt = tensor::ExtractOp::create( + b, loc, elementType, inputMatrix, diagIndices); + linalg::YieldOp::create(b, loc, diagElt); + }) .getResult(0); rewriter.replaceOpWithNewOp(op, outputType, diagonal); @@ -2411,12 +2407,12 @@ class ConvertAtenDiagEmbedOp : public OpConversionPattern { auto resultRank = inputRank + 1; // regardless of offset sign, output tensor is same - Value constOffset = b.create(loc, offset); - Value absOffset = b.create(loc, constOffset); + Value constOffset = arith::ConstantIndexOp::create(b, loc, offset); + Value absOffset = math::AbsIOp::create(b, loc, constOffset); // diagonal size is determined by last input dimension auto lastInputDim = getDimOp(b, loc, tensor, inputRank - 1); - Value diagDim = b.create(loc, lastInputDim, absOffset); + Value diagDim = arith::AddIOp::create(b, loc, lastInputDim, absOffset); // output shape has same dimensions as input // except for the diagonal dimensions @@ -2486,59 +2482,56 @@ class ConvertAtenDiagEmbedOp : public OpConversionPattern { SmallVector iteratorTypes( resultRank, utils::IteratorType::parallel); Value resultTensor = - rewriter - .create( - loc, zeroTensor.getType(), ValueRange{}, zeroTensor, - /*indexingMaps=*/indexingMaps, - /*iteratorTypes=*/iteratorTypes, - [&](OpBuilder &b, Location loc, ValueRange args) { - Value dim1Index = b.create(loc, dim1); - Value dim2Index = b.create(loc, dim2); - - // to pick right element from input, first add all dimensions - // except last one, then last will be either dim1 or dim2 - // depending upon lower or upper diagonal defined by offset - // sign - SmallVector inputIndices; - for (unsigned int i 
= 0; i < resultRank; i++) { - if (i != dim1 && i != dim2) { - inputIndices.push_back(b.create(loc, i)); - } - } - - // adjust output diagonal indices and last input Index based - // on offset - Value dim1IdxAdjusted; - Value dim2IdxAdjusted; - if (offset < 0) { - Value absOffset = - b.create(loc, -offset); - dim1IdxAdjusted = dim1Index; - dim2IdxAdjusted = - b.create(loc, dim2Index, absOffset); - inputIndices.push_back( - b.create(loc, dim2)); - } else { - Value constOffset = - b.create(loc, offset); - dim1IdxAdjusted = - b.create(loc, dim1Index, constOffset); - dim2IdxAdjusted = dim2Index; - inputIndices.push_back( - b.create(loc, dim1)); - } - - Value isDiagonal = - b.create(loc, arith::CmpIPredicate::eq, - dim1IdxAdjusted, dim2IdxAdjusted); - - Value inputElem = b.create( - loc, resultElemType, input, inputIndices); - - Value result = rewriter.create( - loc, isDiagonal, inputElem, args[0]); - b.create(loc, result); - }) + linalg::GenericOp::create( + rewriter, loc, zeroTensor.getType(), ValueRange{}, zeroTensor, + /*indexingMaps=*/indexingMaps, + /*iteratorTypes=*/iteratorTypes, + [&](OpBuilder &b, Location loc, ValueRange args) { + Value dim1Index = linalg::IndexOp::create(b, loc, dim1); + Value dim2Index = linalg::IndexOp::create(b, loc, dim2); + + // to pick right element from input, first add all dimensions + // except last one, then last will be either dim1 or dim2 + // depending upon lower or upper diagonal defined by offset + // sign + SmallVector inputIndices; + for (unsigned int i = 0; i < resultRank; i++) { + if (i != dim1 && i != dim2) { + inputIndices.push_back(linalg::IndexOp::create(b, loc, i)); + } + } + + // adjust output diagonal indices and last input Index based + // on offset + Value dim1IdxAdjusted; + Value dim2IdxAdjusted; + if (offset < 0) { + Value absOffset = + arith::ConstantIndexOp::create(b, loc, -offset); + dim1IdxAdjusted = dim1Index; + dim2IdxAdjusted = + arith::AddIOp::create(b, loc, dim2Index, absOffset); + 
inputIndices.push_back(linalg::IndexOp::create(b, loc, dim2)); + } else { + Value constOffset = + arith::ConstantIndexOp::create(b, loc, offset); + dim1IdxAdjusted = + arith::AddIOp::create(b, loc, dim1Index, constOffset); + dim2IdxAdjusted = dim2Index; + inputIndices.push_back(linalg::IndexOp::create(b, loc, dim1)); + } + + Value isDiagonal = + arith::CmpIOp::create(b, loc, arith::CmpIPredicate::eq, + dim1IdxAdjusted, dim2IdxAdjusted); + + Value inputElem = tensor::ExtractOp::create( + b, loc, resultElemType, input, inputIndices); + + Value result = arith::SelectOp::create(rewriter, loc, isDiagonal, + inputElem, args[0]); + linalg::YieldOp::create(b, loc, result); + }) .getResult(0); RankedTensorType resultType = cast( @@ -2587,16 +2580,16 @@ class ConvertAtenUnfoldOp : public OpConversionPattern { if (size == 0) { RankedTensorType resultType = RankedTensorType::get({0}, selfType.getElementType()); - Value emptyTensor = rewriter.create( - loc, resultType.getShape(), resultType.getElementType()); + Value emptyTensor = tensor::EmptyOp::create( + rewriter, loc, resultType.getShape(), resultType.getElementType()); rewriter.replaceOp(op, emptyTensor); return success(); } - Value unsqueezedSelf = rewriter.create( - loc, RankedTensorType::get({1}, selfType.getElementType()), self, - ArrayRef{}); + Value unsqueezedSelf = tensor::ExpandShapeOp::create( + rewriter, loc, RankedTensorType::get({1}, selfType.getElementType()), + self, ArrayRef{}); rewriter.replaceOp(op, unsqueezedSelf); return success(); } @@ -2610,13 +2603,13 @@ class ConvertAtenUnfoldOp : public OpConversionPattern { return rewriter.notifyMatchFailure(op, "dimension out of range"); } - Value dimSize = rewriter.create(loc, self, dimension); + Value dimSize = tensor::DimOp::create(rewriter, loc, self, dimension); - Value sizeValue = rewriter.create(loc, size); - Value sizeCheck = rewriter.create( - loc, arith::CmpIPredicate::ule, sizeValue, dimSize); - rewriter.create( - loc, sizeCheck, + Value sizeValue = 
arith::ConstantIndexOp::create(rewriter, loc, size); + Value sizeCheck = arith::CmpIOp::create( + rewriter, loc, arith::CmpIPredicate::ule, sizeValue, dimSize); + cf::AssertOp::create( + rewriter, loc, sizeCheck, rewriter.getStringAttr("size must be <= target dimension")); /* Calculate output shape of unfold op: @@ -2631,7 +2624,7 @@ class ConvertAtenUnfoldOp : public OpConversionPattern { rewriter, loc, shape[dimension], dimSize, size, step)); } else if (shape[i] == ShapedType::kDynamic) { outputShape.push_back( - OpFoldResult(rewriter.create(loc, self, i))); + OpFoldResult(tensor::DimOp::create(rewriter, loc, self, i))); } else { outputShape.push_back(rewriter.getIndexAttr(shape[i])); } @@ -2639,8 +2632,8 @@ class ConvertAtenUnfoldOp : public OpConversionPattern { outputShape.push_back(rewriter.getIndexAttr(size)); // Empty tensor to insert values into - Value outputTensor = rewriter.create( - loc, outputShape, selfType.getElementType()); + Value outputTensor = tensor::EmptyOp::create(rewriter, loc, outputShape, + selfType.getElementType()); /** * Use reindexing to map output indices to input indices @@ -2674,13 +2667,12 @@ class ConvertAtenUnfoldOp : public OpConversionPattern { outputRank, utils::IteratorType::parallel); Value result = - rewriter - .create( - loc, outputTensor.getType(), self, outputTensor, - ArrayRef({inputAffineMap, outputAffineMap}), iteratorTypes, - [](OpBuilder &b, Location nestedLoc, ValueRange args) { - b.create(nestedLoc, args[0]); - }) + linalg::GenericOp::create( + rewriter, loc, outputTensor.getType(), self, outputTensor, + ArrayRef({inputAffineMap, outputAffineMap}), iteratorTypes, + [](OpBuilder &b, Location nestedLoc, ValueRange args) { + linalg::YieldOp::create(b, nestedLoc, args[0]); + }) .getResult(0); rewriter.replaceOp(op, result); @@ -2695,13 +2687,15 @@ class ConvertAtenUnfoldOp : public OpConversionPattern { * numBlocks = (shape[dimension] - size) // step + 1 */ if (shapeDim == ShapedType::kDynamic) { - Value 
numBlocksSubOp = rewriter.create( - loc, dimSize, rewriter.create(loc, size)); - Value numBlocksDivOp = rewriter.create( - loc, numBlocksSubOp, - rewriter.create(loc, step)); - Value numBlocks = rewriter.create( - loc, rewriter.create(loc, 1), numBlocksDivOp); + Value numBlocksSubOp = arith::SubIOp::create( + rewriter, loc, dimSize, + arith::ConstantIndexOp::create(rewriter, loc, size)); + Value numBlocksDivOp = arith::DivUIOp::create( + rewriter, loc, numBlocksSubOp, + arith::ConstantIndexOp::create(rewriter, loc, step)); + Value numBlocks = arith::AddIOp::create( + rewriter, loc, arith::ConstantIndexOp::create(rewriter, loc, 1), + numBlocksDivOp); return OpFoldResult(numBlocks); } diff --git a/lib/Conversion/TorchToLinalg/IndirectDataMovement.cpp b/lib/Conversion/TorchToLinalg/IndirectDataMovement.cpp index 07e4b23a167d..dd792a79aaf8 100644 --- a/lib/Conversion/TorchToLinalg/IndirectDataMovement.cpp +++ b/lib/Conversion/TorchToLinalg/IndirectDataMovement.cpp @@ -38,27 +38,27 @@ static void createLinalgPayloadCalculationForGatherOps( // related to the correct dimension of the output for dimension larger // than the given `dim`. int64_t inputDimOffset = i < dim ? 
0 : outputRank - inputRank; - indices.push_back(b.create(loc, i + inputDimOffset)); + indices.push_back(linalg::IndexOp::create(b, loc, i + inputDimOffset)); } } // Assert index < input.sizes[dim] - Value indexLTInputDim = b.create( - loc, arith::CmpIPredicate::slt, castIntToIndex(b, loc, index), + Value indexLTInputDim = arith::CmpIOp::create( + b, loc, arith::CmpIPredicate::slt, castIntToIndex(b, loc, index), getDimOp(b, loc, input, dim)); - b.create( - loc, indexLTInputDim, - b.getStringAttr("index must be smaller than dim size")); + cf::AssertOp::create(b, loc, indexLTInputDim, + b.getStringAttr("index must be smaller than dim size")); // Assert index >= 0 - Value cst0 = b.create(loc, b.getZeroAttr(index.getType())); + Value cst0 = + arith::ConstantOp::create(b, loc, b.getZeroAttr(index.getType())); Value indexGEThanZero = - b.create(loc, arith::CmpIPredicate::sge, index, cst0); - b.create(loc, indexGEThanZero, - b.getStringAttr("index must be larger or equal to 0")); + arith::CmpIOp::create(b, loc, arith::CmpIPredicate::sge, index, cst0); + cf::AssertOp::create(b, loc, indexGEThanZero, + b.getStringAttr("index must be larger or equal to 0")); - Value extract = b.create(loc, input, indices); - b.create(loc, extract); + Value extract = tensor::ExtractOp::create(b, loc, input, indices); + linalg::YieldOp::create(b, loc, extract); } namespace { @@ -96,15 +96,14 @@ class ConvertAtenGatherOp : public OpConversionPattern { rewriter.getMultiDimIdentityMap(rank)); SmallVector iteratorTypes( rank, utils::IteratorType::parallel); - auto genericOp = rewriter - .create( - loc, result.getType(), indices, result, affineMaps, - iteratorTypes, - [&](OpBuilder &b, Location loc, ValueRange args) { - auto index = args[0]; - createLinalgPayloadCalculationForGatherOps( - b, loc, self, rank, index, dim, rank); - }) + auto genericOp = linalg::GenericOp::create( + rewriter, loc, result.getType(), indices, result, + affineMaps, iteratorTypes, + [&](OpBuilder &b, Location loc, 
ValueRange args) { + auto index = args[0]; + createLinalgPayloadCalculationForGatherOps( + b, loc, self, rank, index, dim, rank); + }) .getResult(0); rewriter.replaceOpWithNewOp(op, newResultTy, genericOp); return success(); @@ -151,19 +150,18 @@ class ConvertAtenEmbeddingOp : public OpConversionPattern { }; SmallVector iteratorTypes( sizes.size(), utils::IteratorType::parallel); - Value initTensor = - rewriter.create(loc, getAsOpFoldResult(sizes), elemTy); + Value initTensor = tensor::EmptyOp::create( + rewriter, loc, getAsOpFoldResult(sizes), elemTy); Value embeddingResult = - rewriter - .create( - loc, initTensor.getType(), indices, initTensor, - /*indexingMaps=*/indexingMaps, /*iteratorTypes=*/iteratorTypes, - [&](OpBuilder &b, Location loc, ValueRange args) { - Value index = args[0]; - createLinalgPayloadCalculationForGatherOps( - b, loc, weight, weightTy.getRank(), index, /*dim=*/0, - resultRank); - }) + linalg::GenericOp::create( + rewriter, loc, initTensor.getType(), indices, initTensor, + /*indexingMaps=*/indexingMaps, /*iteratorTypes=*/iteratorTypes, + [&](OpBuilder &b, Location loc, ValueRange args) { + Value index = args[0]; + createLinalgPayloadCalculationForGatherOps( + b, loc, weight, weightTy.getRank(), index, /*dim=*/0, + resultRank); + }) .getResult(0); rewriter.replaceOpWithNewOp(op, newResultType, embeddingResult); @@ -310,87 +308,84 @@ class ConvertAtenEmbeddingBagPaddingIdxOp } Value embeddingBagResult = - rewriter - .create( - loc, initTensor.getType(), ValueRange{indices, offsets}, - initTensor, - /*indexingMaps=*/indexingMaps, - /*iteratorTypes=*/iteratorTypes, - [&](OpBuilder &b, Location loc, ValueRange args) { - Value indexInIndices = args[0]; - Value offsetsI = args[1]; - Value initTensorElem = args[2]; - - Value indexI = b.create(loc, /*value=*/0); - Value indexIToInt = castIndexToInt64(b, loc, indexI); - Value one = - getConstant(b, loc, 1, - mlir::IntegerType::get( - getContext(), 64, IntegerType::Signless)); - Value 
offsetIndexPlusOneInt = - b.create(loc, indexIToInt, one); - - Value offsetIndexPlusOne = - castIntToIndex(b, loc, offsetIndexPlusOneInt); - Value checkLast = b.create( - loc, arith::CmpIPredicate::eq, - castIndexToInt64(b, loc, offsetsLength), - offsetIndexPlusOneInt); - Value nextOffset = b.create( - loc, checkLast, castIndexToInt64(b, loc, indicesLength), - b.create(loc, offsets, - offsetIndexPlusOne)); - - Value indicesIndex = castIndexToInt64( - b, loc, b.create(loc, /*value=*/1)); - - Value offsetLessThanIndicesIndex = b.create( - loc, arith::CmpIPredicate::slt, offsetsI, indicesIndex); - Value offsetEqualToIndicesIndex = b.create( - loc, arith::CmpIPredicate::eq, offsetsI, indicesIndex); - Value offsetLessThanOrEqualToIndicesIndex = - b.create(loc, offsetLessThanIndicesIndex, - offsetEqualToIndicesIndex); - - Value indicesIndexLessThanNextOffset = - b.create(loc, arith::CmpIPredicate::slt, - indicesIndex, nextOffset); - - Value indicesIndexWithinBounds = b.create( - loc, offsetLessThanOrEqualToIndicesIndex, - indicesIndexLessThanNextOffset); - - SmallVector indexIntoWeight; - indexIntoWeight.push_back( - castIntToIndex(b, loc, indexInIndices)); - indexIntoWeight.push_back( - b.create(loc, /*value=*/2)); - Value weightElem = - b.create(loc, weight, indexIntoWeight); - - Value addResult = - b.create(loc, weightElem, initTensorElem); - Value select = b.create( - loc, indicesIndexWithinBounds, addResult, initTensorElem); - b.create(loc, select); - }) + linalg::GenericOp::create( + rewriter, loc, initTensor.getType(), ValueRange{indices, offsets}, + initTensor, + /*indexingMaps=*/indexingMaps, + /*iteratorTypes=*/iteratorTypes, + [&](OpBuilder &b, Location loc, ValueRange args) { + Value indexInIndices = args[0]; + Value offsetsI = args[1]; + Value initTensorElem = args[2]; + + Value indexI = linalg::IndexOp::create(b, loc, /*value=*/0); + Value indexIToInt = castIndexToInt64(b, loc, indexI); + Value one = + getConstant(b, loc, 1, + 
mlir::IntegerType::get(getContext(), 64, + IntegerType::Signless)); + Value offsetIndexPlusOneInt = + arith::AddIOp::create(b, loc, indexIToInt, one); + + Value offsetIndexPlusOne = + castIntToIndex(b, loc, offsetIndexPlusOneInt); + Value checkLast = + arith::CmpIOp::create(b, loc, arith::CmpIPredicate::eq, + castIndexToInt64(b, loc, offsetsLength), + offsetIndexPlusOneInt); + Value nextOffset = arith::SelectOp::create( + b, loc, checkLast, castIndexToInt64(b, loc, indicesLength), + tensor::ExtractOp::create(b, loc, offsets, + offsetIndexPlusOne)); + + Value indicesIndex = castIndexToInt64( + b, loc, linalg::IndexOp::create(b, loc, /*value=*/1)); + + Value offsetLessThanIndicesIndex = arith::CmpIOp::create( + b, loc, arith::CmpIPredicate::slt, offsetsI, indicesIndex); + Value offsetEqualToIndicesIndex = arith::CmpIOp::create( + b, loc, arith::CmpIPredicate::eq, offsetsI, indicesIndex); + Value offsetLessThanOrEqualToIndicesIndex = + arith::OrIOp::create(b, loc, offsetLessThanIndicesIndex, + offsetEqualToIndicesIndex); + + Value indicesIndexLessThanNextOffset = arith::CmpIOp::create( + b, loc, arith::CmpIPredicate::slt, indicesIndex, nextOffset); + + Value indicesIndexWithinBounds = arith::AndIOp::create( + b, loc, offsetLessThanOrEqualToIndicesIndex, + indicesIndexLessThanNextOffset); + + SmallVector indexIntoWeight; + indexIntoWeight.push_back(castIntToIndex(b, loc, indexInIndices)); + indexIntoWeight.push_back( + linalg::IndexOp::create(b, loc, /*value=*/2)); + Value weightElem = + tensor::ExtractOp::create(b, loc, weight, indexIntoWeight); + + Value addResult = + arith::AddFOp::create(b, loc, weightElem, initTensorElem); + Value select = arith::SelectOp::create( + b, loc, indicesIndexWithinBounds, addResult, initTensorElem); + linalg::YieldOp::create(b, loc, select); + }) .getResult(0); // cast outputType. 
auto restulType0 = typeConverter->convertType(op->getResult(0).getType()); Value castedEmbeddingBagResult = - rewriter.create(loc, restulType0, embeddingBagResult); + tensor::CastOp::create(rewriter, loc, restulType0, embeddingBagResult); // offset2 tensor, this should be an empty tensor for the sum mode SmallVector offsetResultSize; Type offsetElemTy = offsetsTy.getElementType(); - Value zeroDim = rewriter.create(loc, /*value=*/0); + Value zeroDim = arith::ConstantIndexOp::create(rewriter, loc, /*value=*/0); offsetResultSize.push_back(zeroDim); - Value offsetResult = rewriter.create( - loc, getAsOpFoldResult(offsetResultSize), offsetElemTy); + Value offsetResult = tensor::EmptyOp::create( + rewriter, loc, getAsOpFoldResult(offsetResultSize), offsetElemTy); auto resultType1 = typeConverter->convertType(op->getResult(1).getType()); Value castedOffsetResult = - rewriter.create(loc, resultType1, offsetResult); + tensor::CastOp::create(rewriter, loc, resultType1, offsetResult); SmallVector offsetSize = getTensorSizes(rewriter, loc, offsets); // bagsize, vector of size offset with zeros, I think this is always just @@ -399,7 +394,7 @@ class ConvertAtenEmbeddingBagPaddingIdxOp createZeroInitTensor(rewriter, loc, offsetSize, offsetElemTy); auto resultType2 = typeConverter->convertType(op->getResult(2).getType()); Value castedBagSizeResult = - rewriter.create(loc, resultType2, bagSize); + tensor::CastOp::create(rewriter, loc, resultType2, bagSize); // max indices, vector of size offset with zeros, this is also always a // vector of zeros in the sum mode. Its mainly used in the max mode. 
@@ -407,7 +402,7 @@ class ConvertAtenEmbeddingBagPaddingIdxOp createZeroInitTensor(rewriter, loc, offsetSize, offsetElemTy); auto resultType3 = typeConverter->convertType(op->getResult(3).getType()); Value castedMaxIndices = - rewriter.create(loc, resultType3, indicesOut); + tensor::CastOp::create(rewriter, loc, resultType3, indicesOut); rewriter.replaceOp(op, {castedEmbeddingBagResult, castedOffsetResult, castedBagSizeResult, castedMaxIndices}); @@ -458,14 +453,14 @@ class ConvertAtenIndexSelectOp : public OpConversionPattern { if (indicesTy.getRank() == 0) { llvm::SmallVector reassociations; indicesTy = RankedTensorType::get({1}, indicesTy.getElementType()); - indices = rewriter.create(loc, indicesTy, indices, - reassociations); + indices = tensor::ExpandShapeOp::create(rewriter, loc, indicesTy, indices, + reassociations); } SmallVector resultShape = getTensorSizes(rewriter, loc, input); resultShape[dimInt] = getTensorSizes(rewriter, loc, indices)[0]; - Value initTensor = rewriter.create( - loc, getAsOpFoldResult(resultShape), elementType); + Value initTensor = tensor::EmptyOp::create( + rewriter, loc, getAsOpFoldResult(resultShape), elementType); SmallVector resultExpr; AffineExpr indicesExpr = rewriter.getAffineDimExpr(dimInt); @@ -480,22 +475,22 @@ class ConvertAtenIndexSelectOp : public OpConversionPattern { rewriter.getContext()); Value finalRes = - rewriter - .create( - loc, initTensor.getType(), ValueRange{indices}, initTensor, - /*indexingMaps=*/indexingMaps, - /*iteratorTypes=*/iteratorTypes, - [&](OpBuilder &b, Location loc, ValueRange args) { - Value index = rewriter.create( - loc, rewriter.getIndexType(), args[0]); - SmallVector indexTarget; - for (unsigned i = 0; i < inputRank; i++) - indexTarget.push_back(b.create(loc, i)); - indexTarget[dimInt] = index; - Value extractedElement = - b.create(loc, input, indexTarget); - b.create(loc, extractedElement); - }) + linalg::GenericOp::create( + rewriter, loc, initTensor.getType(), ValueRange{indices}, + 
initTensor, + /*indexingMaps=*/indexingMaps, + /*iteratorTypes=*/iteratorTypes, + [&](OpBuilder &b, Location loc, ValueRange args) { + Value index = arith::IndexCastOp::create( + rewriter, loc, rewriter.getIndexType(), args[0]); + SmallVector indexTarget; + for (unsigned i = 0; i < inputRank; i++) + indexTarget.push_back(linalg::IndexOp::create(b, loc, i)); + indexTarget[dimInt] = index; + Value extractedElement = + tensor::ExtractOp::create(b, loc, input, indexTarget); + linalg::YieldOp::create(b, loc, extractedElement); + }) .getResult(0); rewriter.replaceOpWithNewOp(op, resultType, finalRes); @@ -506,13 +501,13 @@ class ConvertAtenIndexSelectOp : public OpConversionPattern { static Value makeIndexValuePositive(OpBuilder &b, Location loc, Value index, Value input, int64_t dim) { - Value cstZero = b.create(loc, b.getI64IntegerAttr(0)); + Value cstZero = arith::ConstantOp::create(b, loc, b.getI64IntegerAttr(0)); Value isIndexNegative = - b.create(loc, arith::CmpIPredicate::slt, index, cstZero); + arith::CmpIOp::create(b, loc, arith::CmpIPredicate::slt, index, cstZero); Value inputShape = castIndexToInt64(b, loc, getDimOp(b, loc, input, dim)); - Value toPositiveIndex = b.create(loc, index, inputShape); - return b.create(loc, isIndexNegative, toPositiveIndex, - index); + Value toPositiveIndex = arith::AddIOp::create(b, loc, index, inputShape); + return arith::SelectOp::create(b, loc, isIndexNegative, toPositiveIndex, + index); } // IndexTensor for multiple input tensors broadcasts their shapes to a common @@ -627,11 +622,11 @@ class ConvertAtenIndexTensorHackedTwinOp if (staticDimSize > 1) { Value cstStaticDimSize = getConstant(rewriter, loc, staticDimSize, rewriter.getIndexType()); - auto equalToRunning = rewriter.create( - loc, arith::CmpIPredicate::eq, cstStaticDimSize, - dynamicDims[0]); - rewriter.create(loc, equalToRunning, - "mismatched size for broadcast"); + auto equalToRunning = + arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::eq, + 
cstStaticDimSize, dynamicDims[0]); + cf::AssertOp::create(rewriter, loc, equalToRunning, + "mismatched size for broadcast"); } } broadcastedIndexShape.push_back(dynamicDims[0]); @@ -677,8 +672,8 @@ class ConvertAtenIndexTensorHackedTwinOp // safely map all size 1 dims to 0 in the corresponding affine maps. // TODO: For dynamic shapes, we have to either broadcast the index tensors // to a common shape or introduce some form of control flow. - Value initTensor = rewriter.create( - loc, getAsOpFoldResult(resultShape), elementType); + Value initTensor = tensor::EmptyOp::create( + rewriter, loc, getAsOpFoldResult(resultShape), elementType); SmallVector indexingMaps; for (auto indexTensor : indexTensors) { @@ -711,49 +706,48 @@ class ConvertAtenIndexTensorHackedTwinOp AffineMap::get(resultRank, 0, resultExpr, op->getContext())); Value finalRes = - rewriter - .create( - loc, initTensor.getType(), indexTensors, initTensor, - indexingMaps, iteratorTypes, - [&](OpBuilder &b, Location loc, ValueRange args) { - SmallVector extractionIndices; - if (contiguous) { - for (auto i : llvm::seq(0, firstIndexDim)) { - extractionIndices.push_back( - b.create(loc, i)); - } - for (auto i : llvm::seq(0, (int)indexTensorDims.size())) { - extractionIndices.push_back(castIntToIndex( - b, loc, - makeIndexValuePositive(b, loc, args[i], input, - extractionIndices.size()))); - } - for (auto i : - llvm::seq((int)extractionIndices.size(), inputRank)) { - extractionIndices.push_back(b.create( - loc, i + broadcastRank - replacedIndexCount)); - } - } else { - int indexCount = 0, unchanged = 0; - for (auto i : llvm::seq(0, inputRank)) { - if (indexCount < replacedIndexCount && - i == indexTensorDims[indexCount]) { - extractionIndices.push_back(castIntToIndex( - b, loc, - makeIndexValuePositive(b, loc, args[indexCount++], - input, - extractionIndices.size()))); - continue; - } - extractionIndices.push_back(b.create( - loc, broadcastRank + unchanged)); - unchanged++; - } + linalg::GenericOp::create( + 
rewriter, loc, initTensor.getType(), indexTensors, initTensor, + indexingMaps, iteratorTypes, + [&](OpBuilder &b, Location loc, ValueRange args) { + SmallVector extractionIndices; + if (contiguous) { + for (auto i : llvm::seq(0, firstIndexDim)) { + extractionIndices.push_back( + linalg::IndexOp::create(b, loc, i)); + } + for (auto i : llvm::seq(0, (int)indexTensorDims.size())) { + extractionIndices.push_back(castIntToIndex( + b, loc, + makeIndexValuePositive(b, loc, args[i], input, + extractionIndices.size()))); + } + for (auto i : + llvm::seq((int)extractionIndices.size(), inputRank)) { + extractionIndices.push_back(linalg::IndexOp::create( + b, loc, i + broadcastRank - replacedIndexCount)); + } + } else { + int indexCount = 0, unchanged = 0; + for (auto i : llvm::seq(0, inputRank)) { + if (indexCount < replacedIndexCount && + i == indexTensorDims[indexCount]) { + extractionIndices.push_back(castIntToIndex( + b, loc, + makeIndexValuePositive(b, loc, args[indexCount++], + input, + extractionIndices.size()))); + continue; } - Value extractedElement = b.create( - loc, input, extractionIndices); - b.create(loc, extractedElement); - }) + extractionIndices.push_back(linalg::IndexOp::create( + b, loc, broadcastRank + unchanged)); + unchanged++; + } + } + Value extractedElement = + tensor::ExtractOp::create(b, loc, input, extractionIndices); + linalg::YieldOp::create(b, loc, extractedElement); + }) .getResult(0); rewriter.replaceOpWithNewOp(op, resultType, finalRes); @@ -769,7 +763,7 @@ static Value getScaleFactor(OpBuilder &builder, Location loc, Value dim, Value scaledDim) { Value dimInt = castIndexToInt64(builder, loc, dim); Value scaleFactorInt = - builder.create(loc, scaledDim, dimInt); + arith::CeilDivSIOp::create(builder, loc, scaledDim, dimInt); return scaleFactorInt; } @@ -822,9 +816,9 @@ class ConvertAtenUpsampleNearest2dOp if (!isa(op.getScalesH().getType())) { // Convert float values to int values. 
// int_value = (int64_t)ceil(float_value) - Value ceilVal = rewriter.create(loc, adaptor.getScalesH()); - Value intVal = - rewriter.create(loc, rewriter.getI64Type(), ceilVal); + Value ceilVal = math::CeilOp::create(rewriter, loc, adaptor.getScalesH()); + Value intVal = arith::FPToSIOp::create(rewriter, loc, + rewriter.getI64Type(), ceilVal); scaleFactorsInt.push_back(intVal); } else { auto scaleFactorVal = @@ -835,9 +829,9 @@ class ConvertAtenUpsampleNearest2dOp if (!isa(op.getScalesW().getType())) { // Convert float values to int values. // int_value = (int64_t)ceil(float_value) - Value ceilVal = rewriter.create(loc, adaptor.getScalesW()); - Value intVal = - rewriter.create(loc, rewriter.getI64Type(), ceilVal); + Value ceilVal = math::CeilOp::create(rewriter, loc, adaptor.getScalesW()); + Value intVal = arith::FPToSIOp::create(rewriter, loc, + rewriter.getI64Type(), ceilVal); scaleFactorsInt.push_back(intVal); } else { auto scaleFactorVal = @@ -851,33 +845,31 @@ class ConvertAtenUpsampleNearest2dOp dims[hDimOffset + 1] = castIntToIndex(rewriter, loc, outputSizeIntValues[1]); - Value outTensor = rewriter.create( - loc, getAsOpFoldResult(dims), elementType); + Value outTensor = tensor::EmptyOp::create( + rewriter, loc, getAsOpFoldResult(dims), elementType); AffineMap idMap = rewriter.getMultiDimIdentityMap(inputRank); SmallVector iteratorTypes( inputRank, utils::IteratorType::parallel); Value finalRes = - rewriter - .create( - loc, outTensor.getType(), ValueRange{}, outTensor, - /*indexingMaps=*/idMap, - /*iteratorTypes=*/iteratorTypes, - [&](OpBuilder &b, Location loc, ValueRange args) { - SmallVector indices; - for (unsigned i = 0; i < inputRank; i++) - indices.push_back(b.create(loc, i)); - - for (unsigned i = 0; i < (inputRank - hDimOffset); i++) - indices[i + hDimOffset] = b.create( - loc, indices[i + hDimOffset], - castIntToIndex(rewriter, loc, scaleFactorsInt[i])); - - Value retVal = - b.create(loc, input, indices); - b.create(loc, retVal); - }) + 
linalg::GenericOp::create( + rewriter, loc, outTensor.getType(), ValueRange{}, outTensor, + /*indexingMaps=*/idMap, + /*iteratorTypes=*/iteratorTypes, + [&](OpBuilder &b, Location loc, ValueRange args) { + SmallVector indices; + for (unsigned i = 0; i < inputRank; i++) + indices.push_back(linalg::IndexOp::create(b, loc, i)); + + for (unsigned i = 0; i < (inputRank - hDimOffset); i++) + indices[i + hDimOffset] = arith::FloorDivSIOp::create( + b, loc, indices[i + hDimOffset], + castIntToIndex(rewriter, loc, scaleFactorsInt[i])); + + Value retVal = tensor::ExtractOp::create(b, loc, input, indices); + linalg::YieldOp::create(b, loc, retVal); + }) .getResult(0); rewriter.replaceOpWithNewOp(op, resultType, finalRes); @@ -893,39 +885,43 @@ static Value getGradOutputValue(OpBuilder &builder, Location loc, Value kernelIndexH, Value kernelIndexW, SmallVector &gradOutputSizeIndexValues, SmallVector &scaleFactorsIntValues) { - Value constantOne = builder.create(loc, 1); + Value constantOne = arith::ConstantIndexOp::create(builder, loc, 1); - Value outputIndexH = builder.create( - loc, inputIndexH, castIntToIndex(builder, loc, scaleFactorsIntValues[0])); - outputIndexH = builder.create(loc, outputIndexH, kernelIndexH); + Value outputIndexH = arith::MulIOp::create( + builder, loc, inputIndexH, + castIntToIndex(builder, loc, scaleFactorsIntValues[0])); + outputIndexH = + arith::AddIOp::create(builder, loc, outputIndexH, kernelIndexH); - Value outputIndexW = builder.create( - loc, inputIndexW, castIntToIndex(builder, loc, scaleFactorsIntValues[1])); - outputIndexW = builder.create(loc, outputIndexW, kernelIndexW); + Value outputIndexW = arith::MulIOp::create( + builder, loc, inputIndexW, + castIntToIndex(builder, loc, scaleFactorsIntValues[1])); + outputIndexW = + arith::AddIOp::create(builder, loc, outputIndexW, kernelIndexW); // Handling corner cases. 
- Value gradOutputHMinusOne = builder.create( - loc, gradOutputSizeIndexValues[2], constantOne); - Value predH = builder.create( - loc, arith::CmpIPredicate::sle, outputIndexH, gradOutputHMinusOne); - outputIndexH = builder.create(loc, predH, outputIndexH, - gradOutputHMinusOne); - - Value gradOutputWMinusOne = builder.create( - loc, gradOutputSizeIndexValues[3], constantOne); - Value predW = builder.create( - loc, arith::CmpIPredicate::sle, outputIndexW, gradOutputWMinusOne); - outputIndexW = builder.create(loc, predW, outputIndexW, - gradOutputWMinusOne); - - Value gradOutputValue = builder.create( - loc, gradOutput, + Value gradOutputHMinusOne = arith::SubIOp::create( + builder, loc, gradOutputSizeIndexValues[2], constantOne); + Value predH = arith::CmpIOp::create(builder, loc, arith::CmpIPredicate::sle, + outputIndexH, gradOutputHMinusOne); + outputIndexH = arith::SelectOp::create(builder, loc, predH, outputIndexH, + gradOutputHMinusOne); + + Value gradOutputWMinusOne = arith::SubIOp::create( + builder, loc, gradOutputSizeIndexValues[3], constantOne); + Value predW = arith::CmpIOp::create(builder, loc, arith::CmpIPredicate::sle, + outputIndexW, gradOutputWMinusOne); + outputIndexW = arith::SelectOp::create(builder, loc, predW, outputIndexW, + gradOutputWMinusOne); + + Value gradOutputValue = tensor::ExtractOp::create( + builder, loc, gradOutput, ValueRange{numBatch, numChannel, outputIndexH, outputIndexW}); Value constantZero = - builder.create(loc, builder.getF32FloatAttr(0.0)); - Value pred = builder.create(loc, predH, predW); - Value result = builder.create( - loc, pred, gradOutputValue, + arith::ConstantOp::create(builder, loc, builder.getF32FloatAttr(0.0)); + Value pred = arith::AndIOp::create(builder, loc, predH, predW); + Value result = arith::SelectOp::create( + builder, loc, pred, gradOutputValue, convertScalarToDtype(builder, loc, constantZero, gradOutputElemType)); return result; @@ -983,8 +979,8 @@ class ConvertAtenUpsampleNearest2dBackwardOp if 
(!isa(op.getScalesH().getType())) { scaleFactorsFloatValues.push_back(adaptor.getScalesH()); } else { - auto scaleFactorVal = rewriter.create( - loc, + auto scaleFactorVal = arith::DivFOp::create( + rewriter, loc, convertScalarToDtype(rewriter, loc, gradOutputSizeIntValues[hDimOffset], mlir::Float32Type::get(op->getContext())), @@ -996,8 +992,8 @@ class ConvertAtenUpsampleNearest2dBackwardOp if (!isa(op.getScalesW().getType())) { scaleFactorsFloatValues.push_back(adaptor.getScalesW()); } else { - auto scaleFactorVal = rewriter.create( - loc, + auto scaleFactorVal = arith::DivFOp::create( + rewriter, loc, convertScalarToDtype(rewriter, loc, gradOutputSizeIntValues[hDimOffset + 1], mlir::Float32Type::get(op->getContext())), @@ -1010,7 +1006,7 @@ class ConvertAtenUpsampleNearest2dBackwardOp SmallVector scaleFactorsIntValues; for (auto v : scaleFactorsFloatValues) scaleFactorsIntValues.push_back(convertScalarToDtype( - rewriter, loc, rewriter.create(loc, v), + rewriter, loc, math::CeilOp::create(rewriter, loc, v), mlir::IntegerType::get(op->getContext(), 64))); Value outTensor = createZeroInitTensor( @@ -1018,11 +1014,11 @@ class ConvertAtenUpsampleNearest2dBackwardOp castIntVectorToIndexVector(rewriter, loc, inputSizeIntValues), elementType); - Value kernelTensor = rewriter.create( - loc, - getAsOpFoldResult( - castIntVectorToIndexVector(rewriter, loc, scaleFactorsIntValues)), - elementType); + Value kernelTensor = + tensor::EmptyOp::create(rewriter, loc, + getAsOpFoldResult(castIntVectorToIndexVector( + rewriter, loc, scaleFactorsIntValues)), + elementType); unsigned kernelRank = scaleFactorsIntValues.size(); SmallVector affineExprs; @@ -1048,27 +1044,26 @@ class ConvertAtenUpsampleNearest2dBackwardOp iteratorTypes.push_back(utils::IteratorType::reduction); Value finalRes = - rewriter - .create( - loc, outTensor.getType(), ValueRange{kernelTensor}, - ValueRange{outTensor}, - /*indexingMaps=*/indexingMaps, - /*iteratorTypes=*/iteratorTypes, - [&](OpBuilder &b, 
Location loc, ValueRange args) { - Value n = rewriter.create(loc, 0); - Value c = rewriter.create(loc, 1); - Value ih = rewriter.create(loc, 2); - Value iw = rewriter.create(loc, 3); - Value kh = rewriter.create(loc, 4); - Value kw = rewriter.create(loc, 5); - Value accValue = getGradOutputValue( - rewriter, loc, gradOutput, elementType, n, c, ih, iw, kh, - kw, gradOutputSizeIndexValues, scaleFactorsIntValues); - Value outputVal = args[1]; - outputVal = - rewriter.create(loc, outputVal, accValue); - b.create(loc, outputVal); - }) + linalg::GenericOp::create( + rewriter, loc, outTensor.getType(), ValueRange{kernelTensor}, + ValueRange{outTensor}, + /*indexingMaps=*/indexingMaps, + /*iteratorTypes=*/iteratorTypes, + [&](OpBuilder &b, Location loc, ValueRange args) { + Value n = linalg::IndexOp::create(rewriter, loc, 0); + Value c = linalg::IndexOp::create(rewriter, loc, 1); + Value ih = linalg::IndexOp::create(rewriter, loc, 2); + Value iw = linalg::IndexOp::create(rewriter, loc, 3); + Value kh = linalg::IndexOp::create(rewriter, loc, 4); + Value kw = linalg::IndexOp::create(rewriter, loc, 5); + Value accValue = getGradOutputValue( + rewriter, loc, gradOutput, elementType, n, c, ih, iw, kh, kw, + gradOutputSizeIndexValues, scaleFactorsIntValues); + Value outputVal = args[1]; + outputVal = + arith::AddFOp::create(rewriter, loc, outputVal, accValue); + linalg::YieldOp::create(b, loc, outputVal); + }) ->getResult(0); rewriter.replaceOpWithNewOp(op, resultType, finalRes); diff --git a/lib/Conversion/TorchToLinalg/Linear.cpp b/lib/Conversion/TorchToLinalg/Linear.cpp index a3a486d14136..68947a953b7a 100644 --- a/lib/Conversion/TorchToLinalg/Linear.cpp +++ b/lib/Conversion/TorchToLinalg/Linear.cpp @@ -41,17 +41,17 @@ static void signShift(PatternRewriter &rewriter, Location loc, Value &arg, if (!isUnsignedType) return; int64_t minSI = -(1 << (numBits - 1)); - Value minSIValue = rewriter.create( - loc, minSI, cast(zp.getType()).getWidth()); - zp = rewriter.create(loc, zp, 
minSIValue); - minSIValue = rewriter.create(loc, minSI, numBits); + Value minSIValue = arith::ConstantIntOp::create( + rewriter, loc, minSI, cast(zp.getType()).getWidth()); + zp = arith::AddIOp::create(rewriter, loc, zp, minSIValue); + minSIValue = arith::ConstantIntOp::create(rewriter, loc, minSI, numBits); arg = torch_to_linalg::createElementwiseLinalgGeneric( rewriter, loc, ValueRange{arg}, cast(arg.getType()).getElementType(), [&](OpBuilder &b, Location loc, ValueRange payloadArgs) { Value result = - rewriter.create(loc, payloadArgs[0], minSIValue); - b.create(loc, result); + arith::AddIOp::create(rewriter, loc, payloadArgs[0], minSIValue); + linalg::YieldOp::create(b, loc, result); }); } @@ -64,14 +64,14 @@ static Value transposeValue(Location loc, Value value, ArrayRef perms, for (size_t i = 0; i < perms.size(); ++i) { outShape.push_back(inShape[perms[i]]); if (ShapedType::isDynamic(inShape[perms[i]])) { - dynDims.push_back(rewriter.create(loc, value, perms[i])); + dynDims.push_back(tensor::DimOp::create(rewriter, loc, value, perms[i])); } } auto outTy = RankedTensorType::get(outShape, valueTy.getElementType()); - Value empty = rewriter.create(loc, outTy, dynDims); + Value empty = tensor::EmptyOp::create(rewriter, loc, outTy, dynDims); Value transpose = - rewriter.create(loc, value, empty, perms) + linalg::TransposeOp::create(rewriter, loc, value, empty, perms) ->getResult(0); return transpose; } @@ -133,16 +133,16 @@ class ConvertAtenMmOp : public OpConversionPattern { bool isUnsigned = torch_to_linalg::isUnsignedTorchType(lhsTorchType); bool isUnsignedR = torch_to_linalg::isUnsignedTorchType(rhsTorchType); - Value lhsDim0 = rewriter.create(loc, lhs, 0); - Value rhsDim1 = rewriter.create(loc, rhs, 1); + Value lhsDim0 = tensor::DimOp::create(rewriter, loc, lhs, 0); + Value rhsDim1 = tensor::DimOp::create(rewriter, loc, rhs, 1); if (!isAssumingStrictSymbolicShapes(rewriter)) { - Value lhsDim1 = rewriter.create(loc, lhs, 1); - Value rhsDim0 = 
rewriter.create(loc, rhs, 0); - Value contractingDimEqual = rewriter.create( - loc, arith::CmpIPredicate::eq, lhsDim1, rhsDim0); - rewriter.create( - loc, contractingDimEqual, + Value lhsDim1 = tensor::DimOp::create(rewriter, loc, lhs, 1); + Value rhsDim0 = tensor::DimOp::create(rewriter, loc, rhs, 0); + Value contractingDimEqual = arith::CmpIOp::create( + rewriter, loc, arith::CmpIPredicate::eq, lhsDim1, rhsDim0); + cf::AssertOp::create( + rewriter, loc, contractingDimEqual, rewriter.getStringAttr( "mismatching contracting dimension for torch.aten.mm")); } @@ -168,10 +168,10 @@ class ConvertAtenMmOp : public OpConversionPattern { rewriter, loc, getTypeConverter()->convertType(rhsZeroPoint.getType()), rhsZeroPoint); - lhsZeroPoint = rewriter.create( - loc, rewriter.getI32Type(), lhsZeroPoint); - rhsZeroPoint = rewriter.create( - loc, rewriter.getI32Type(), rhsZeroPoint); + lhsZeroPoint = arith::TruncIOp::create( + rewriter, loc, rewriter.getI32Type(), lhsZeroPoint); + rhsZeroPoint = arith::TruncIOp::create( + rewriter, loc, rewriter.getI32Type(), rhsZeroPoint); // change uint8 quantization -> int8 quantization int64_t numBits = @@ -180,21 +180,18 @@ class ConvertAtenMmOp : public OpConversionPattern { numBits = cast(rhsType.getElementType()).getWidth(); signShift(rewriter, loc, rhs, rhsZeroPoint, isUnsignedR, numBits); - matmul = - rewriter - .create( - loc, zeroFill.getType(), - ValueRange{lhs, rhs, lhsZeroPoint, rhsZeroPoint}, zeroFill) - .getResult(0); + matmul = linalg::QuantizedMatmulOp::create( + rewriter, loc, zeroFill.getType(), + ValueRange{lhs, rhs, lhsZeroPoint, rhsZeroPoint}, zeroFill) + .getResult(0); } else if (isUnsigned) { - auto matmulOp = rewriter.create( - loc, zeroFill.getType(), ValueRange{lhs, rhs}, zeroFill); + auto matmulOp = linalg::MatmulOp::create( + rewriter, loc, zeroFill.getType(), ValueRange{lhs, rhs}, zeroFill); matmulOp.setCast(linalg::TypeFn::cast_unsigned); matmul = matmulOp->getResult(0); } else { - matmul = rewriter - 
.create(loc, zeroFill.getType(), - ValueRange{lhs, rhs}, zeroFill) + matmul = linalg::MatmulOp::create(rewriter, loc, zeroFill.getType(), + ValueRange{lhs, rhs}, zeroFill) .getResult(0); } @@ -301,10 +298,10 @@ class ConvertAtenMatmulOp : public OpConversionPattern { rewriter, loc, getTypeConverter()->convertType(rhsZeroPoint.getType()), rhsZeroPoint); - lhsZeroPoint = rewriter.create( - loc, rewriter.getI32Type(), lhsZeroPoint); - rhsZeroPoint = rewriter.create( - loc, rewriter.getI32Type(), rhsZeroPoint); + lhsZeroPoint = arith::TruncIOp::create( + rewriter, loc, rewriter.getI32Type(), lhsZeroPoint); + rhsZeroPoint = arith::TruncIOp::create( + rewriter, loc, rewriter.getI32Type(), rhsZeroPoint); // change uint8 quantization -> int8 quantization int64_t numBits = @@ -328,16 +325,16 @@ class ConvertAtenMatmulOp : public OpConversionPattern { int64_t lhsDim = lhsType.getShape()[0]; auto lhsUnsqueezeType = RankedTensorType::get( ArrayRef{1, lhsDim}, lhsType.getElementType()); - lhs = rewriter.create(loc, lhsUnsqueezeType, - lhs, reassociation); + lhs = tensor::ExpandShapeOp::create(rewriter, loc, lhsUnsqueezeType, + lhs, reassociation); } if (rhsVec) { // unsqueeze rhs to a matrix int64_t rhsDim = rhsType.getShape()[0]; auto rhsUnsqueezeType = RankedTensorType::get( ArrayRef{rhsDim, 1}, rhsType.getElementType()); - rhs = rewriter.create(loc, rhsUnsqueezeType, - rhs, reassociation); + rhs = tensor::ExpandShapeOp::create(rewriter, loc, rhsUnsqueezeType, + rhs, reassociation); } // get quantized_matmul and squeeze result Value lhsDim0 = getDimOp(rewriter, loc, lhs, 0); @@ -348,19 +345,18 @@ class ConvertAtenMatmulOp : public OpConversionPattern { Value zeroTensor = createZeroInitTensor( rewriter, loc, ValueRange{lhsDim0, rhsDim1}, elementType); - Value matmul = rewriter - .create( - loc, zeroTensor.getType(), - ValueRange{lhs, rhs, lhsZeroPoint, rhsZeroPoint}, - zeroTensor) - .getResult(0); + Value matmul = + linalg::QuantizedMatmulOp::create( + rewriter, loc, 
zeroTensor.getType(), + ValueRange{lhs, rhs, lhsZeroPoint, rhsZeroPoint}, zeroTensor) + .getResult(0); int64_t resultRank = resultType.getRank(); if (resultRank == 0) { // in vec-vec case, need to collapse result to a scalar reassociation.clear(); } - matmul = rewriter.create( - loc, resultType, matmul, reassociation); + matmul = tensor::CollapseShapeOp::create(rewriter, loc, resultType, + matmul, reassociation); rewriter.replaceOpWithNewOp(op, newResultType, matmul); return success(); } @@ -379,11 +375,9 @@ class ConvertAtenMatmulOp : public OpConversionPattern { checkDimEqualHelper(rewriter, loc, lhsDim0, rhsDim0); Value zeroTensor = createZeroInitTensor(rewriter, loc, {}, elementType); - Value dotProd = - rewriter - .create(loc, zeroTensor.getType(), - ValueRange{lhs, rhs}, zeroTensor) - .getResult(0); + Value dotProd = linalg::DotOp::create(rewriter, loc, zeroTensor.getType(), + ValueRange{lhs, rhs}, zeroTensor) + .getResult(0); rewriter.replaceOpWithNewOp(op, newResultType, dotProd); return success(); } @@ -398,9 +392,8 @@ class ConvertAtenMatmulOp : public OpConversionPattern { Value zeroTensor = createZeroInitTensor(rewriter, loc, ValueRange{rhsDim1}, elementType); Value matmul = - rewriter - .create(loc, zeroTensor.getType(), - ValueRange{lhs, rhs}, zeroTensor) + linalg::VecmatOp::create(rewriter, loc, zeroTensor.getType(), + ValueRange{lhs, rhs}, zeroTensor) .getResult(0); rewriter.replaceOpWithNewOp(op, newResultType, matmul); return success(); @@ -416,9 +409,8 @@ class ConvertAtenMatmulOp : public OpConversionPattern { Value zeroTensor = createZeroInitTensor(rewriter, loc, ValueRange{lhsDim0}, elementType); Value matmul = - rewriter - .create(loc, zeroTensor.getType(), - ValueRange{lhs, rhs}, zeroTensor) + linalg::MatvecOp::create(rewriter, loc, zeroTensor.getType(), + ValueRange{lhs, rhs}, zeroTensor) .getResult(0); rewriter.replaceOpWithNewOp(op, newResultType, matmul); return success(); @@ -436,16 +428,14 @@ class ConvertAtenMatmulOp : public 
OpConversionPattern { rewriter, loc, ValueRange{lhsDim0, rhsDim1}, elementType); Value matmul; if (lhsZeroPoint) { - matmul = rewriter - .create( - loc, zeroTensor.getType(), - ValueRange{lhs, rhs, lhsZeroPoint, rhsZeroPoint}, - zeroTensor) - .getResult(0); + matmul = + linalg::QuantizedMatmulOp::create( + rewriter, loc, zeroTensor.getType(), + ValueRange{lhs, rhs, lhsZeroPoint, rhsZeroPoint}, zeroTensor) + .getResult(0); } else { - matmul = rewriter - .create(loc, zeroTensor.getType(), - ValueRange{lhs, rhs}, zeroTensor) + matmul = linalg::MatmulOp::create(rewriter, loc, zeroTensor.getType(), + ValueRange{lhs, rhs}, zeroTensor) .getResult(0); } rewriter.replaceOpWithNewOp(op, newResultType, matmul); @@ -538,21 +528,19 @@ class ConvertAtenMatmulOp : public OpConversionPattern { elementType); Value matmul; if (lhsZeroPoint) { - matmul = rewriter - .create( - loc, zeroTensor.getType(), - ValueRange{broadcastedLhs, broadcastedRhs, - lhsZeroPoint, rhsZeroPoint}, - zeroTensor) + matmul = linalg::QuantizedBatchMatmulOp::create( + rewriter, loc, zeroTensor.getType(), + ValueRange{broadcastedLhs, broadcastedRhs, lhsZeroPoint, + rhsZeroPoint}, + zeroTensor) .getResult(0); rewriter.replaceOpWithNewOp(op, newResultType, matmul); return success(); } - matmul = rewriter - .create( - loc, zeroTensor.getType(), - ValueRange{broadcastedLhs, broadcastedRhs}, zeroTensor) + matmul = linalg::BatchMatmulOp::create( + rewriter, loc, zeroTensor.getType(), + ValueRange{broadcastedLhs, broadcastedRhs}, zeroTensor) .getResult(0); rewriter.replaceOpWithNewOp(op, newResultType, matmul); return success(); @@ -579,10 +567,10 @@ class ConvertAtenMatmulOp : public OpConversionPattern { j++; reassociation[j].push_back(i); } - Value collapsedLhs = rewriter.create( - op->getLoc(), broadcastedLhs, reassociation); - Value collapsedRhs = rewriter.create( - op->getLoc(), broadcastedRhs, reassociation); + Value collapsedLhs = tensor::CollapseShapeOp::create( + rewriter, op->getLoc(), broadcastedLhs, 
reassociation); + Value collapsedRhs = tensor::CollapseShapeOp::create( + rewriter, op->getLoc(), broadcastedRhs, reassociation); // Compute the result shape after collapsing the batch dimensions. SmallVector collapsedResultShape; @@ -596,32 +584,29 @@ class ConvertAtenMatmulOp : public OpConversionPattern { SmallVector updatedCollapseResultShape = getAsOpFoldResult(collapsedResultShape); - Value initTensor = rewriter.create( - loc, updatedCollapseResultShape, elementType); - Value c0 = rewriter.create( - loc, rewriter.getZeroAttr(elementType)); + Value initTensor = tensor::EmptyOp::create( + rewriter, loc, updatedCollapseResultShape, elementType); + Value c0 = arith::ConstantOp::create(rewriter, loc, + rewriter.getZeroAttr(elementType)); Value zeroTensor = - rewriter.create(loc, c0, initTensor).getResult(0); + linalg::FillOp::create(rewriter, loc, c0, initTensor).getResult(0); Value batchMatMul; if (lhsZeroPoint) { - batchMatMul = rewriter - .create( - loc, zeroTensor.getType(), - ValueRange{collapsedLhs, collapsedRhs, - lhsZeroPoint, rhsZeroPoint}, - zeroTensor) + batchMatMul = linalg::QuantizedBatchMatmulOp::create( + rewriter, loc, zeroTensor.getType(), + ValueRange{collapsedLhs, collapsedRhs, lhsZeroPoint, + rhsZeroPoint}, + zeroTensor) .getResult(0); } else { - batchMatMul = - rewriter - .create( - loc, zeroTensor.getType(), - ValueRange{collapsedLhs, collapsedRhs}, zeroTensor) - .getResult(0); + batchMatMul = linalg::BatchMatmulOp::create( + rewriter, loc, zeroTensor.getType(), + ValueRange{collapsedLhs, collapsedRhs}, zeroTensor) + .getResult(0); } - Value expandResult = rewriter.create( - loc, resultType, batchMatMul, reassociation); + Value expandResult = tensor::ExpandShapeOp::create( + rewriter, loc, resultType, batchMatMul, reassociation); rewriter.replaceOpWithNewOp(op, newResultType, expandResult); return success(); @@ -656,18 +641,17 @@ class ConvertAtenMatmulOp : public OpConversionPattern { utils::IteratorType::parallel}); Value finalRes = - 
rewriter - .create( - loc, zeroTensor.getType(), - ValueRange{broadcastedLhs, broadcastedRhs}, zeroTensor, - /*indexingMaps=*/indexingMaps, - /*iteratorTypes=*/iteratorTypes, - [&](OpBuilder &b, Location loc, ValueRange args) { - Value l = args[0], r = args[1], res = args[2]; - Value mul = b.create(loc, l, r); - Value add = b.create(loc, mul, res); - b.create(loc, add); - }) + linalg::GenericOp::create( + rewriter, loc, zeroTensor.getType(), + ValueRange{broadcastedLhs, broadcastedRhs}, zeroTensor, + /*indexingMaps=*/indexingMaps, + /*iteratorTypes=*/iteratorTypes, + [&](OpBuilder &b, Location loc, ValueRange args) { + Value l = args[0], r = args[1], res = args[2]; + Value mul = arith::MulFOp::create(b, loc, l, r); + Value add = arith::AddFOp::create(b, loc, mul, res); + linalg::YieldOp::create(b, loc, add); + }) .getResult(0); rewriter.replaceOpWithNewOp(op, newResultType, finalRes); @@ -736,9 +720,8 @@ class ConvertAtenBmmOp : public OpConversionPattern { rewriter, loc, ValueRange{lhsDim0, lhsDim1, rhsDim2}, accumulatorDType); Value bmm = - rewriter - .create(loc, initTensor0.getType(), - ValueRange{lhs, rhs}, initTensor0) + linalg::BatchMatmulOp::create(rewriter, loc, initTensor0.getType(), + ValueRange{lhs, rhs}, initTensor0) .getResult(0); if (accumulatorDType != resultElementType) { @@ -803,8 +786,8 @@ class ConvertAtenConvolutionOp : public OpConversionPattern { inputZp = typeConverter->materializeTargetConversion( rewriter, loc, typeConverter->convertType(inputZp.getType()), inputZp); - inputZp = - rewriter.create(loc, rewriter.getI32Type(), inputZp); + inputZp = arith::TruncIOp::create(rewriter, loc, rewriter.getI32Type(), + inputZp); auto torchDtype = cast(make.getType()).getDtype(); inputUnsigned = torch_to_linalg::isUnsignedTorchType(torchDtype); } @@ -819,8 +802,8 @@ class ConvertAtenConvolutionOp : public OpConversionPattern { weightZp = typeConverter->materializeTargetConversion( rewriter, loc, typeConverter->convertType(weightZp.getType()), 
weightZp); - weightZp = rewriter.create(loc, rewriter.getI32Type(), - weightZp); + weightZp = arith::TruncIOp::create(rewriter, loc, rewriter.getI32Type(), + weightZp); auto torchDtype = cast(make.getType()).getDtype(); weightUnsigned = torch_to_linalg::isUnsignedTorchType(torchDtype); } @@ -916,8 +899,8 @@ class ConvertAtenConvolutionOp : public OpConversionPattern { } weight = unsqueezeWeightInfo.value(); - Value cstZero = rewriter.create( - loc, rewriter.getI64IntegerAttr(0)); + Value cstZero = arith::ConstantOp::create(rewriter, loc, + rewriter.getI64IntegerAttr(0)); paddingIntValues.push_back(cstZero); outputPaddingIntValues.push_back(cstZero); strideInts.push_back(1); @@ -940,12 +923,12 @@ class ConvertAtenConvolutionOp : public OpConversionPattern { auto validate = [&](Value toValidate, std::string err) { Value c0 = - rewriter.create(loc, rewriter.getIndexAttr(0)); - Value inputValid = rewriter.create( - loc, arith::CmpIPredicate::eq, c0, - rewriter.create(loc, toValidate, groups)); - rewriter.create(loc, inputValid, - rewriter.getStringAttr(err)); + arith::ConstantOp::create(rewriter, loc, rewriter.getIndexAttr(0)); + Value inputValid = arith::CmpIOp::create( + rewriter, loc, arith::CmpIPredicate::eq, c0, + arith::RemSIOp::create(rewriter, loc, toValidate, groups)); + cf::AssertOp::create(rewriter, loc, inputValid, + rewriter.getStringAttr(err)); }; validate(inChannels, "invalid: groups must divide input channel size evenly."); @@ -971,18 +954,18 @@ class ConvertAtenConvolutionOp : public OpConversionPattern { Value pad = inputZp; if (!pad) { if (isa(inputDTy)) - pad = rewriter.create( - op.getLoc(), rewriter.getFloatAttr(inputDTy, 0.0)); + pad = arith::ConstantOp::create(rewriter, op.getLoc(), + rewriter.getFloatAttr(inputDTy, 0.0)); if (isa(inputDTy)) - pad = rewriter.create( - op.getLoc(), rewriter.getIntegerAttr(inputDTy, 0)); + pad = arith::ConstantOp::create(rewriter, op.getLoc(), + rewriter.getIntegerAttr(inputDTy, 0)); } if (pad.getType() != 
inputDTy) { if (isa(inputDTy)) - pad = rewriter.create(op.getLoc(), inputDTy, pad); + pad = arith::TruncFOp::create(rewriter, op.getLoc(), inputDTy, pad); if (isa(inputDTy)) - pad = rewriter.create(op.getLoc(), inputDTy, pad); + pad = arith::TruncIOp::create(rewriter, op.getLoc(), inputDTy, pad); } // The expandWeight lambda function below is used to expand the group @@ -1025,8 +1008,8 @@ class ConvertAtenConvolutionOp : public OpConversionPattern { indices.push_back({i}); auto retType = inType.clone(makeShapeLLVMCompatible(outShape)); - return rewriter.create(loc, retType, tensor, - indices); + return tensor::ExpandShapeOp::create(rewriter, loc, retType, tensor, + indices); }; if (transposed) { @@ -1034,9 +1017,9 @@ class ConvertAtenConvolutionOp : public OpConversionPattern { weight = isGroupedConv ? expandWeight(weight) : weight; Value c0 = - rewriter.create(loc, rewriter.getIndexAttr(0)); + arith::ConstantOp::create(rewriter, loc, rewriter.getIndexAttr(0)); Value c1 = - rewriter.create(loc, rewriter.getIndexAttr(1)); + arith::ConstantOp::create(rewriter, loc, rewriter.getIndexAttr(1)); // Transpose and flip weight SmallVector weightInitDims = getTensorSizes(rewriter, loc, weight); @@ -1045,7 +1028,7 @@ class ConvertAtenConvolutionOp : public OpConversionPattern { // output dimension needs to consider the number of groups. 
std::iter_swap(weightInitDims.begin() + 1, weightInitDims.begin() + 2); auto numGroupsVal = - rewriter.create(loc, numGroups); + mlir::arith::ConstantIndexOp::create(rewriter, loc, numGroups); outDims[1] = rewriter.createOrFold( loc, weightInitDims[1], numGroupsVal); } else { @@ -1059,32 +1042,31 @@ class ConvertAtenConvolutionOp : public OpConversionPattern { weightRank, utils::IteratorType::parallel); SmallVector indexingMaps{ AffineMap::getMultiDimIdentityMap(weightRank, context)}; - weight = rewriter - .create( - loc, weightInitTensor.getType(), ValueRange{}, - weightInitTensor, indexingMaps, iteratorTypes, - [&](OpBuilder &b, Location loc, ValueRange args) { - SmallVector indices; - for (size_t i = 0; i < weightRank; i++) - indices.push_back(b.create(loc, i)); - auto fcIdxSwapOffset = isGroupedConv ? 1 : 0; - std::iter_swap(indices.begin() + fcIdxSwapOffset, - indices.begin() + fcIdxSwapOffset + 1); - // Flip only the spatial dimensions (from 2 to - // weightRank) - for (size_t flipDim = fcIdxSwapOffset + 2; - flipDim < weightRank; flipDim++) { - indices[flipDim] = b.create( - loc, - b.create( - loc, weightInitDims[flipDim], c1), - indices[flipDim]); - } - Value res = - b.create(loc, weight, indices) - .getResult(); - b.create(loc, res); - }) + weight = linalg::GenericOp::create( + rewriter, loc, weightInitTensor.getType(), ValueRange{}, + weightInitTensor, indexingMaps, iteratorTypes, + [&](OpBuilder &b, Location loc, ValueRange args) { + SmallVector indices; + for (size_t i = 0; i < weightRank; i++) + indices.push_back(linalg::IndexOp::create(b, loc, i)); + auto fcIdxSwapOffset = isGroupedConv ? 
1 : 0; + std::iter_swap(indices.begin() + fcIdxSwapOffset, + indices.begin() + fcIdxSwapOffset + 1); + // Flip only the spatial dimensions (from 2 to + // weightRank) + for (size_t flipDim = fcIdxSwapOffset + 2; + flipDim < weightRank; flipDim++) { + indices[flipDim] = arith::SubIOp::create( + b, loc, + arith::SubIOp::create(b, loc, + weightInitDims[flipDim], c1), + indices[flipDim]); + } + Value res = + tensor::ExtractOp::create(b, loc, weight, indices) + .getResult(); + linalg::YieldOp::create(b, loc, res); + }) .getResult(0); paddedInput = createTransposedInputPadding( @@ -1115,8 +1097,8 @@ class ConvertAtenConvolutionOp : public OpConversionPattern { } Type accumulatorDType = getDefaultAccType(rewriter, inputDTy); - Value initTensor = rewriter.create( - loc, getAsOpFoldResult(outDims), accumulatorDType); + Value initTensor = tensor::EmptyOp::create( + rewriter, loc, getAsOpFoldResult(outDims), accumulatorDType); Value outputTensor; if (accumulatorDType != resultDTy && !isa(bias.getType())) @@ -1125,14 +1107,14 @@ class ConvertAtenConvolutionOp : public OpConversionPattern { if (isa(bias.getType())) { Value c0; if (isa(accumulatorDType)) { - c0 = rewriter.create( - loc, FloatAttr::get(accumulatorDType, 0.0)); + c0 = arith::ConstantOp::create(rewriter, loc, + FloatAttr::get(accumulatorDType, 0.0)); } else if (isa(accumulatorDType)) { - c0 = rewriter.create( - loc, IntegerAttr::get(accumulatorDType, 0)); + c0 = arith::ConstantOp::create(rewriter, loc, + IntegerAttr::get(accumulatorDType, 0)); } outputTensor = - rewriter.create(loc, c0, initTensor).getResult(0); + linalg::FillOp::create(rewriter, loc, c0, initTensor).getResult(0); } else { auto biasType = cast(bias.getType()); @@ -1146,9 +1128,8 @@ class ConvertAtenConvolutionOp : public OpConversionPattern { for (int i = 0; i < resultRank; ++i) if (i != 1) addedDimensions.push_back(i); - outputTensor = rewriter - .create(loc, bias, initTensor, - addedDimensions) + outputTensor = 
linalg::BroadcastOp::create(rewriter, loc, bias, + initTensor, addedDimensions) ->getResult(0); } @@ -1156,14 +1137,16 @@ class ConvertAtenConvolutionOp : public OpConversionPattern { auto dilationAttr = rewriter.getI64VectorAttr(dilationInts); Value inputStride = - rewriter.create(loc, inChannels, groups); + arith::FloorDivSIOp::create(rewriter, loc, inChannels, groups); Value weightStride = - rewriter.create(loc, weightBatch, groups); - - SmallVector zeroOffsets(inRank, rewriter.create( - loc, rewriter.getIndexAttr(0))); - SmallVector unitStrides(inRank, rewriter.create( - loc, rewriter.getIndexAttr(1))); + arith::FloorDivSIOp::create(rewriter, loc, weightBatch, groups); + + SmallVector zeroOffsets( + inRank, + arith::ConstantOp::create(rewriter, loc, rewriter.getIndexAttr(0))); + SmallVector unitStrides( + inRank, + arith::ConstantOp::create(rewriter, loc, rewriter.getIndexAttr(1))); SmallVector outDimSlice(outDims); outDimSlice[1] = weightStride; SmallVector inputSliceSizes{inBatch, inputStride}; @@ -1183,27 +1166,24 @@ class ConvertAtenConvolutionOp : public OpConversionPattern { if (numGroups == 1 && !inputZp) { switch (numSpatialDims) { case 1: - conv = rewriter - .create( - loc, outputTensor.getType(), - ValueRange{paddedInput, weight}, outputTensor, - stridesAttr, dilationAttr) + conv = linalg::Conv1DNcwFcwOp::create( + rewriter, loc, outputTensor.getType(), + ValueRange{paddedInput, weight}, outputTensor, stridesAttr, + dilationAttr) .getResult(0); break; case 2: - conv = rewriter - .create( - loc, outputTensor.getType(), - ValueRange{paddedInput, weight}, outputTensor, - stridesAttr, dilationAttr) + conv = linalg::Conv2DNchwFchwOp::create( + rewriter, loc, outputTensor.getType(), + ValueRange{paddedInput, weight}, outputTensor, stridesAttr, + dilationAttr) .getResult(0); break; case 3: - conv = rewriter - .create( - loc, outputTensor.getType(), - ValueRange{paddedInput, weight}, outputTensor, - stridesAttr, dilationAttr) + conv = 
linalg::Conv3DNcdhwFcdhwOp::create( + rewriter, loc, outputTensor.getType(), + ValueRange{paddedInput, weight}, outputTensor, stridesAttr, + dilationAttr) .getResult(0); break; default: @@ -1224,11 +1204,10 @@ class ConvertAtenConvolutionOp : public OpConversionPattern { if (numGroups == 1 && inputZp) { switch (numSpatialDims) { case 2: - conv = rewriter - .create( - loc, outputTensor.getType(), - ValueRange{paddedInput, weight, inputZp, weightZp}, - outputTensor, stridesAttr, dilationAttr) + conv = linalg::Conv2DNchwFchwQOp::create( + rewriter, loc, outputTensor.getType(), + ValueRange{paddedInput, weight, inputZp, weightZp}, + outputTensor, stridesAttr, dilationAttr) .getResult(0); break; case 3: { @@ -1251,11 +1230,10 @@ class ConvertAtenConvolutionOp : public OpConversionPattern { outputTensor = transposeValue(op.getLoc(), outputTensor, inPerms, rewriter); - conv = rewriter - .create( - loc, outputTensor.getType(), - ValueRange{paddedInput, weight, inputZp, weightZp}, - outputTensor, stridesAttr, dilationAttr) + conv = linalg::Conv3DNdhwcDhwcfQOp::create( + rewriter, loc, outputTensor.getType(), + ValueRange{paddedInput, weight, inputZp, weightZp}, + outputTensor, stridesAttr, dilationAttr) .getResult(0); llvm::SmallVector outPerms; @@ -1303,24 +1281,22 @@ class ConvertAtenConvolutionOp : public OpConversionPattern { } Type collapsedType = RankedTensorType::get( makeShapeLLVMCompatible(collapsedShape), weightDTy); - Value collapsedWeight = rewriter.create( - loc, collapsedType, weight, collapsedDims); + Value collapsedWeight = tensor::CollapseShapeOp::create( + rewriter, loc, collapsedType, weight, collapsedDims); if (!inputZp) { switch (numSpatialDims) { case 1: - conv = rewriter - .create( - loc, outputTensor.getType(), - ValueRange{paddedInput, collapsedWeight}, outputTensor, - stridesAttr, dilationAttr) + conv = linalg::DepthwiseConv1DNcwCwOp::create( + rewriter, loc, outputTensor.getType(), + ValueRange{paddedInput, collapsedWeight}, outputTensor, + 
stridesAttr, dilationAttr) .getResult(0); break; case 2: - conv = rewriter - .create( - loc, outputTensor.getType(), - ValueRange{paddedInput, collapsedWeight}, outputTensor, - stridesAttr, dilationAttr) + conv = linalg::DepthwiseConv2DNchwChwOp::create( + rewriter, loc, outputTensor.getType(), + ValueRange{paddedInput, collapsedWeight}, outputTensor, + stridesAttr, dilationAttr) .getResult(0); break; default: @@ -1358,13 +1334,11 @@ class ConvertAtenConvolutionOp : public OpConversionPattern { outputTensor = transposeValue(op.getLoc(), outputTensor, inPerms, rewriter); - conv = - rewriter - .create( - loc, outputTensor.getType(), - ValueRange{paddedInput, collapsedWeight, inputZp, weightZp}, - outputTensor, stridesAttr, dilationAttr) - .getResult(0); + conv = linalg::DepthwiseConv2DNhwcHwcQOp::create( + rewriter, loc, outputTensor.getType(), + ValueRange{paddedInput, collapsedWeight, inputZp, weightZp}, + outputTensor, stridesAttr, dilationAttr) + .getResult(0); // convert output nhwc -> nchw conv = transposeValue(op.getLoc(), conv, resultPerms, rewriter); } @@ -1428,8 +1402,8 @@ class ConvertAtenConvolutionOp : public OpConversionPattern { } auto retType = inType.clone(makeShapeLLVMCompatible(outShape)); - return rewriter.create(loc, retType, tensor, - indices); + return tensor::ExpandShapeOp::create(rewriter, loc, retType, tensor, + indices); }; Value paddedInputExpanded = expandGroups(paddedInput, 1); @@ -1442,21 +1416,17 @@ class ConvertAtenConvolutionOp : public OpConversionPattern { if (numSpatialDims == 2) { // 2D grouped convolution if (!inputZp) { - conv = - rewriter - .create( - loc, expandOutputTensor.getResultType(), - ValueRange{paddedInputExpanded, weight}, - expandOutputTensor.getResult(), stridesAttr, dilationAttr) - .getResult(0); + conv = linalg::Conv2DNgchwGfchwOp::create( + rewriter, loc, expandOutputTensor.getResultType(), + ValueRange{paddedInputExpanded, weight}, + expandOutputTensor.getResult(), stridesAttr, dilationAttr) + .getResult(0); } 
else { - conv = - rewriter - .create( - loc, expandOutputTensor.getResultType(), - ValueRange{paddedInputExpanded, weight, inputZp, weightZp}, - expandOutputTensor.getResult(), stridesAttr, dilationAttr) - .getResult(0); + conv = linalg::Conv2DNgchwGfchwQOp::create( + rewriter, loc, expandOutputTensor.getResultType(), + ValueRange{paddedInputExpanded, weight, inputZp, weightZp}, + expandOutputTensor.getResult(), stridesAttr, dilationAttr) + .getResult(0); } } else if (numSpatialDims == 3) { // MLIR does not have a named 3D grouped convolution op, so we use @@ -1509,34 +1479,32 @@ class ConvertAtenConvolutionOp : public OpConversionPattern { utils::IteratorType::reduction // KW }; - conv = - rewriter - .create( - loc, expandOutputTensor.getResultType(), - ValueRange{paddedInputExpanded, weight}, - expandOutputTensor.getResult(), indexingMaps, iteratorTypes, - [&](OpBuilder &b, Location loc, ValueRange args) { - Value input = args[0]; - Value weight = args[1]; - Value output = args[2]; - - // Convert input and weight to accumulator type if needed - Type accType = output.getType(); - if (input.getType() != accType) { - input = b.create(loc, accType, input); - } - if (weight.getType() != accType) { - weight = b.create(loc, accType, weight); - } - - Value mul = b.create(loc, input, weight); - Value add = b.create(loc, mul, output); - b.create(loc, add); - }) - .getResult(0); + conv = linalg::GenericOp::create( + rewriter, loc, expandOutputTensor.getResultType(), + ValueRange{paddedInputExpanded, weight}, + expandOutputTensor.getResult(), indexingMaps, iteratorTypes, + [&](OpBuilder &b, Location loc, ValueRange args) { + Value input = args[0]; + Value weight = args[1]; + Value output = args[2]; + + // Convert input and weight to accumulator type if needed + Type accType = output.getType(); + if (input.getType() != accType) { + input = arith::ExtFOp::create(b, loc, accType, input); + } + if (weight.getType() != accType) { + weight = arith::ExtFOp::create(b, loc, accType, 
weight); + } + + Value mul = arith::MulFOp::create(b, loc, input, weight); + Value add = arith::AddFOp::create(b, loc, mul, output); + linalg::YieldOp::create(b, loc, add); + }) + .getResult(0); } - conv = rewriter.create( - loc, outputTensor.getType(), conv, + conv = tensor::CollapseShapeOp::create( + rewriter, loc, outputTensor.getType(), conv, expandOutputTensor.getReassociationIndices()); Type newResultType = getTypeConverter()->convertType(op.getType()); if (accumulatorDType != resultDTy) { @@ -1594,7 +1562,7 @@ Value ConvertAtenConvolutionOp::createTransposedInputPadding( SmallVector extractSliceOffsets{c0, c0}; bool anyDimensionPaddingIsNegative = false; - Value c2 = rewriter.create(loc, rewriter.getIndexAttr(2)); + Value c2 = arith::ConstantOp::create(rewriter, loc, rewriter.getIndexAttr(2)); for (size_t i = 0; i < numSpatialDims; i++) { Value innerSize = rewriter.createOrFold(loc, inDims[i], c1); @@ -1647,13 +1615,14 @@ Value ConvertAtenConvolutionOp::createTransposedInputPadding( auto insertSliceOpInput = input; if (anyDimensionPaddingIsNegative) { - insertSliceOpInput = rewriter.create( - loc, torch_to_linalg::removeSizeInformation(rewriter, loc, input), + insertSliceOpInput = tensor::ExtractSliceOp::create( + rewriter, loc, + torch_to_linalg::removeSizeInformation(rewriter, loc, input), extractSliceOffsets, sliceSizes, strideIndexValues); } - auto paddedInput = rewriter.create( - loc, + auto paddedInput = tensor::InsertSliceOp::create( + rewriter, loc, torch_to_linalg::removeSizeInformation(rewriter, loc, insertSliceOpInput), initTensor, insertSliceOffsets, sliceSizes, strideIndexValues); return paddedInput; @@ -1691,8 +1660,8 @@ Value getDFTMatmulCoeff(OpBuilder b, Location loc, values.push_back(std::complex(real, imag)); } } - return b.create( - loc, matrixType, DenseElementsAttr::get(matrixType, values)); + return arith::ConstantOp::create(b, loc, matrixType, + DenseElementsAttr::get(matrixType, values)); } struct ConvertAtenFftRfftOp final : 
OpConversionPattern { @@ -1817,22 +1786,21 @@ struct ConvertAtenFftRfftOp final : OpConversionPattern { utils::IteratorType::parallel}); Value complexRes = - rewriter - .create( - loc, zeroTensor.getType(), - /*inputs=*/ValueRange{lhs, rhs}, - /*outputs=*/zeroTensor, indexingMaps, iteratorTypes, - [&](OpBuilder &b, Location loc, ValueRange args) { - Value l = args[0], r = args[1], res = args[2]; - Value re = b.create(loc, elemType, r); - Value im = b.create(loc, elemType, r); - Value mulRe = b.create(loc, l, re); - Value mulIm = b.create(loc, l, im); - Value mulCplx = b.create( - loc, complexElemType, mulRe, mulIm); - Value add = b.create(loc, mulCplx, res); - b.create(loc, add); - }) + linalg::GenericOp::create( + rewriter, loc, zeroTensor.getType(), + /*inputs=*/ValueRange{lhs, rhs}, + /*outputs=*/zeroTensor, indexingMaps, iteratorTypes, + [&](OpBuilder &b, Location loc, ValueRange args) { + Value l = args[0], r = args[1], res = args[2]; + Value re = complex::ReOp::create(b, loc, elemType, r); + Value im = complex::ImOp::create(b, loc, elemType, r); + Value mulRe = arith::MulFOp::create(b, loc, l, re); + Value mulIm = arith::MulFOp::create(b, loc, l, im); + Value mulCplx = complex::CreateOp::create(b, loc, complexElemType, + mulRe, mulIm); + Value add = complex::AddOp::create(b, loc, mulCplx, res); + linalg::YieldOp::create(b, loc, add); + }) .getResult(0); // Transpose back diff --git a/lib/Conversion/TorchToLinalg/Pooling.cpp b/lib/Conversion/TorchToLinalg/Pooling.cpp index 3d5ddf91ff40..3917bf49a19f 100644 --- a/lib/Conversion/TorchToLinalg/Pooling.cpp +++ b/lib/Conversion/TorchToLinalg/Pooling.cpp @@ -161,7 +161,7 @@ static LogicalResult createPoolingOp( return op->emitError("unimplemented: non-floating point type"); Value initValue = - rewriter.create(loc, cast(initValueAttr)); + arith::ConstantOp::create(rewriter, loc, cast(initValueAttr)); paddedInput = padInputTensor(op, rewriter, self, ceilMode, dimensionality, strideInts, paddingInts, initValue); @@ 
-173,8 +173,8 @@ static LogicalResult createPoolingOp( auto stridesAttr = rewriter.getI64VectorAttr(strideInts); auto dilationAttr = rewriter.getI64VectorAttr(dilationInts); auto shape = castIntVectorToIndexVector(rewriter, loc, kernelSizeIntValues); - Value windowTensor = rewriter.create( - loc, getAsOpFoldResult(shape), elementType); + Value windowTensor = tensor::EmptyOp::create( + rewriter, loc, getAsOpFoldResult(shape), elementType); Value permutedInput = paddedInput, permutedOutput = outTensorInitialized; if (dimensionality == 3) { @@ -194,12 +194,10 @@ static LogicalResult createPoolingOp( op, "failed to perform permutation of tensor"); } - Value poolingResult = - rewriter - .create(loc, permutedOutput.getType(), - ValueRange{permutedInput, windowTensor}, permutedOutput, - stridesAttr, dilationAttr) - .getResult(0); + Value poolingResult = OpTy::create(rewriter, loc, permutedOutput.getType(), + ValueRange{permutedInput, windowTensor}, + permutedOutput, stridesAttr, dilationAttr) + .getResult(0); result = poolingResult; if (dimensionality == 3) { @@ -236,13 +234,13 @@ static Value createMaxUnpoolOp(Operation *op, int64_t poolingDimensionality, SmallVector outSizePadded; for (auto &&[i, size] : llvm::enumerate(resType.getShape())) { if (int64_t(i) < NC) { - outSizePadded.emplace_back(rewriter.create(loc, self, i)); + outSizePadded.emplace_back(tensor::DimOp::create(rewriter, loc, self, i)); continue; } int64_t pad = padding[i - NC]; outSizePadded.emplace_back( - rewriter.create(loc, size + pad)); + arith::ConstantIndexOp::create(rewriter, loc, size + pad)); } // In case if input tensor size is not divisible by stride @@ -270,18 +268,20 @@ static Value createMaxUnpoolOp(Operation *op, int64_t poolingDimensionality, // Pad the indices tensor with a value which cannot appear in real data // (-1) so it will never match. In this case we can pad self with any // value, as it will never affect the output. 
- Value zero = rewriter.create( - loc, rewriter.getZeroAttr(selfType.getElementType())); - Value invalidIdx = rewriter.create( - loc, rewriter.getIntegerAttr(indicesType.getElementType(), -1)); + Value zero = arith::ConstantOp::create( + rewriter, loc, rewriter.getZeroAttr(selfType.getElementType())); + Value invalidIdx = arith::ConstantOp::create( + rewriter, loc, + rewriter.getIntegerAttr(indicesType.getElementType(), -1)); self = torch_to_linalg::getPaddedTensor(op, rewriter, self, low, high, zero); indices = torch_to_linalg::getPaddedTensor(op, rewriter, indices, low, high, invalidIdx); } - Value init = rewriter.create( - loc, getAsOpFoldResult(outSizePadded), selfType.getElementType()); + Value init = + tensor::EmptyOp::create(rewriter, loc, getAsOpFoldResult(outSizePadded), + selfType.getElementType()); SmallVector inputExprs; SmallVector outputExprs; @@ -308,24 +308,24 @@ static Value createMaxUnpoolOp(Operation *op, int64_t poolingDimensionality, // values which came from indices tensor. Value ret; for (auto i : llvm::seq(NC, outRank)) { - Value idx = b.create(loc, i); + Value idx = linalg::IndexOp::create(b, loc, i); // If pool input was padded, adjust indices so they start at 0 in the // non-padded area. Indices outside non-padded area will make no sense, // but it doesnt matter as we will cut the padded area later by // extract_slice. 
int64_t pad = padding[i - NC]; if (pad != 0) { - Value padVal = b.create(loc, pad); - idx = b.create(loc, idx, padVal); + Value padVal = arith::ConstantIndexOp::create(b, loc, pad); + idx = arith::SubIOp::create(b, loc, idx, padVal); } if (!ret) { ret = idx; } else { Value size = - b.create(loc, resType.getShape()[i]); - ret = b.create(loc, ret, size); - ret = b.create(loc, ret, idx); + arith::ConstantIndexOp::create(b, loc, resType.getShape()[i]); + ret = arith::MulIOp::create(b, loc, ret, size); + ret = arith::AddIOp::create(b, loc, ret, idx); } } return ret; @@ -335,24 +335,23 @@ static Value createMaxUnpoolOp(Operation *op, int64_t poolingDimensionality, // Compute current output linear index and compare it with the value // from indices arg. Value input = args[0]; - Value zero = - b.create(loc, rewriter.getZeroAttr(input.getType())); - Value index = b.create(loc, indexType, args[1]); + Value zero = arith::ConstantOp::create( + b, loc, rewriter.getZeroAttr(input.getType())); + Value index = arith::IndexCastOp::create(b, loc, indexType, args[1]); Value currentIndex = computeIndex(b, loc); - Value cmp = b.create(loc, arith::CmpIPredicate::eq, index, - currentIndex); - Value out = b.create(loc, cmp, input, zero); - b.create(loc, out); + Value cmp = arith::CmpIOp::create(b, loc, arith::CmpIPredicate::eq, index, + currentIndex); + Value out = arith::SelectOp::create(b, loc, cmp, input, zero); + linalg::YieldOp::create(b, loc, out); }; Value result = - rewriter - .create(loc, - /*resultTensorTypes=*/init.getType(), - /*inputs=*/ValueRange({self, indices}), - /*outputs=*/init, - /*indexingMaps=*/indexingMaps, - /*iteratorTypes=*/iteratorTypes, builder) + linalg::GenericOp::create(rewriter, loc, + /*resultTensorTypes=*/init.getType(), + /*inputs=*/ValueRange({self, indices}), + /*outputs=*/init, + /*indexingMaps=*/indexingMaps, + /*iteratorTypes=*/iteratorTypes, builder) .getResult(0); if (llvm::any_of(padding, [](int64_t v) { return v != 0; })) { @@ -368,16 +367,16 @@ 
static Value createMaxUnpoolOp(Operation *op, int64_t poolingDimensionality, continue; } - sizeVals.emplace_back(rewriter.create(loc, self, i)); + sizeVals.emplace_back(tensor::DimOp::create(rewriter, loc, self, i)); } SmallVector stridesVals(outRank, rewriter.getI64IntegerAttr(1)); - result = rewriter.create(loc, result, offsetVals, - sizeVals, stridesVals); + result = tensor::ExtractSliceOp::create(rewriter, loc, result, offsetVals, + sizeVals, stridesVals); } if (result.getType() != resType) - result = rewriter.create(loc, resType, result); + result = tensor::CastOp::create(rewriter, loc, resType, result); return result; } @@ -445,7 +444,7 @@ class ConvertAtenMaxPoolOp : public OpConversionPattern { APFloat::getInf(cast(elementType).getFloatSemantics(), /*Negative=*/true)); Value initValue = - rewriter.create(op->getLoc(), smallestFPValueAttr); + arith::ConstantOp::create(rewriter, op->getLoc(), smallestFPValueAttr); paddedInput = padInputTensor(op, rewriter, self, ceilMode, 3, strideInts, paddingInts, initValue); @@ -456,8 +455,8 @@ class ConvertAtenMaxPoolOp : public OpConversionPattern { auto shape = castIntVectorToIndexVector(rewriter, op->getLoc(), kernelSizeIntValues); - Value windowTensor = rewriter.create( - op->getLoc(), getAsOpFoldResult(shape), elementType); + Value windowTensor = tensor::EmptyOp::create( + rewriter, op->getLoc(), getAsOpFoldResult(shape), elementType); MLIRContext *context = rewriter.getContext(); @@ -501,20 +500,19 @@ class ConvertAtenMaxPoolOp : public OpConversionPattern { SmallVector(5, utils::IteratorType::parallel); iteratorTypes.append(3, utils::IteratorType::reduction); SmallVector indexingMaps = {mapInput, mapKernel, mapOutput}; - poolingOp = rewriter - .create( - op->getLoc(), - /* result types */ outTensorInitialized.getType(), - /* operands */ ValueRange({paddedInput, windowTensor}), - /* outputs */ outTensorInitialized, - /*indexingMaps=*/indexingMaps, - /*iteratorTypes=*/iteratorTypes, - [&](OpBuilder &b, Location loc, 
ValueRange args) { - Value currentVal = args[0], accMaxValue = args[2]; - Value max_result = b.create( - loc, currentVal, accMaxValue); - b.create(loc, max_result); - }) + poolingOp = linalg::GenericOp::create( + rewriter, op->getLoc(), + /* result types */ outTensorInitialized.getType(), + /* operands */ ValueRange({paddedInput, windowTensor}), + /* outputs */ outTensorInitialized, + /*indexingMaps=*/indexingMaps, + /*iteratorTypes=*/iteratorTypes, + [&](OpBuilder &b, Location loc, ValueRange args) { + Value currentVal = args[0], accMaxValue = args[2]; + Value max_result = arith::MaximumFOp::create( + b, loc, currentVal, accMaxValue); + linalg::YieldOp::create(b, loc, max_result); + }) .getResult(0); return success(); @@ -554,8 +552,8 @@ class ConvertAtenMaxPoolOp : public OpConversionPattern { Location loc = op->getLoc(); RankedTensorType indicesRankedTensorType = cast( this->getTypeConverter()->convertType(op->getResult(1).getType())); - Value cstMinusOne = - rewriter.create(loc, rewriter.getI64IntegerAttr(-1)); + Value cstMinusOne = arith::ConstantOp::create( + rewriter, loc, rewriter.getI64IntegerAttr(-1)); Value indicesTensor = createInitTensor(rewriter, loc, outTensorShape, indicesRankedTensorType.getElementType(), cstMinusOne); @@ -569,9 +567,9 @@ class ConvertAtenMaxPoolOp : public OpConversionPattern { SmallVector kernelStride = getAsConstantIndexValues(rewriter, loc, strideInts); - Value windowTensor = rewriter.create( - loc, getAsOpFoldResult(kernelSize), - indicesRankedTensorType.getElementType()); + Value windowTensor = + tensor::EmptyOp::create(rewriter, loc, getAsOpFoldResult(kernelSize), + indicesRankedTensorType.getElementType()); SmallVector inputExprs, outputExprs, kernelExprs; for (unsigned i = 0; i < rank; i++) { @@ -599,60 +597,57 @@ class ConvertAtenMaxPoolOp : public OpConversionPattern { } indicesResult = - rewriter - .create( - loc, /*resultTensorTypes=*/indicesTensor.getType(), - /*inputs=*/ValueRange({maxPool, windowTensor}), - 
/*outputs=*/indicesTensor, - /*indexingMaps=*/indexingMaps, - /*iteratorTypes=*/iteratorTypes, - [&](OpBuilder &b, Location loc, ValueRange args) { - Value maxVal = args[0], res = args[2]; - - SmallVector inputDims; - inputDims.append({b.create(loc, 0), - b.create(loc, 1)}); - for (unsigned i = 2; i < rank; i++) { - Value mainIndex = b.create(loc, i); - Value subIndex = - b.create(loc, i + rank - 2); - Value origin = b.create(loc, mainIndex, - kernelStride[i - 2]); - Value offset = - b.create(loc, subIndex, dilation[i - 2]); - inputDims.push_back( - b.create(loc, origin, offset)); - } - - Value input = - b.create(loc, paddedInput, inputDims); - Value pred = b.create( - loc, arith::CmpFPredicate::OEQ, input, maxVal); - - Value outIndex = - b.create(loc, b.getIndexAttr(0)); - Value curInputStride = - b.create(loc, b.getIndexAttr(1)); - for (unsigned i = 0; i < rank - 2; i++) { - Value minusPadding = b.create( - loc, inputDims[rank - 1 - i], padding[rank - 3 - i]); - Value timesStride = b.create( - loc, minusPadding, curInputStride); - outIndex = - b.create(loc, outIndex, timesStride); - curInputStride = b.create( - loc, curInputStride, inputSubShape[rank - 3 - i]); - } - Value result = b.create( - loc, pred, castIndexToInt64(b, loc, outIndex), res); - - Value predInvalidIndex = b.create( - loc, arith::CmpIPredicate::eq, res, cstMinusOne); - Value out = b.create(loc, predInvalidIndex, - result, res); - - b.create(loc, out); - }) + linalg::GenericOp::create( + rewriter, loc, /*resultTensorTypes=*/indicesTensor.getType(), + /*inputs=*/ValueRange({maxPool, windowTensor}), + /*outputs=*/indicesTensor, + /*indexingMaps=*/indexingMaps, + /*iteratorTypes=*/iteratorTypes, + [&](OpBuilder &b, Location loc, ValueRange args) { + Value maxVal = args[0], res = args[2]; + + SmallVector inputDims; + inputDims.append({linalg::IndexOp::create(b, loc, 0), + linalg::IndexOp::create(b, loc, 1)}); + for (unsigned i = 2; i < rank; i++) { + Value mainIndex = linalg::IndexOp::create(b, loc, 
i); + Value subIndex = linalg::IndexOp::create(b, loc, i + rank - 2); + Value origin = arith::MulIOp::create(b, loc, mainIndex, + kernelStride[i - 2]); + Value offset = + arith::MulIOp::create(b, loc, subIndex, dilation[i - 2]); + inputDims.push_back( + arith::AddIOp::create(b, loc, origin, offset)); + } + + Value input = + tensor::ExtractOp::create(b, loc, paddedInput, inputDims); + Value pred = arith::CmpFOp::create( + b, loc, arith::CmpFPredicate::OEQ, input, maxVal); + + Value outIndex = + arith::ConstantOp::create(b, loc, b.getIndexAttr(0)); + Value curInputStride = + arith::ConstantOp::create(b, loc, b.getIndexAttr(1)); + for (unsigned i = 0; i < rank - 2; i++) { + Value minusPadding = arith::SubIOp::create( + b, loc, inputDims[rank - 1 - i], padding[rank - 3 - i]); + Value timesStride = + arith::MulIOp::create(b, loc, minusPadding, curInputStride); + outIndex = arith::AddIOp::create(b, loc, outIndex, timesStride); + curInputStride = arith::MulIOp::create( + b, loc, curInputStride, inputSubShape[rank - 3 - i]); + } + Value result = arith::SelectOp::create( + b, loc, pred, castIndexToInt64(b, loc, outIndex), res); + + Value predInvalidIndex = arith::CmpIOp::create( + b, loc, arith::CmpIPredicate::eq, res, cstMinusOne); + Value out = arith::SelectOp::create(b, loc, predInvalidIndex, + result, res); + + linalg::YieldOp::create(b, loc, out); + }) .getResult(0); return success(); @@ -732,8 +727,8 @@ class ConvertAtenMaxPoolOp : public OpConversionPattern { return rewriter.notifyMatchFailure(op, "unable to compute maxpool3d"); } - Value outMaxPool = rewriter.create( - op->getLoc(), maxPoolResultType, maxPool); + Value outMaxPool = tensor::CastOp::create(rewriter, op->getLoc(), + maxPoolResultType, maxPool); SmallVector outResult({outMaxPool}); if (withIndices) { Value indicesResult; @@ -745,8 +740,8 @@ class ConvertAtenMaxPoolOp : public OpConversionPattern { "unable to compute maxpool indices"); Type indicesResultType = 
typeConverter->convertType(op->getResult(1).getType()); - Value outIndices = rewriter.create( - op->getLoc(), indicesResultType, indicesResult); + Value outIndices = tensor::CastOp::create( + rewriter, op->getLoc(), indicesResultType, indicesResult); outResult.push_back(outIndices); } rewriter.replaceOp(op, outResult); @@ -1007,9 +1002,8 @@ Value PoolSizeCalculator::getPoolSize( // change, these variables used "height" and "width" (or "h" and "w") // in these intermediate variables instead of "Dim". - Value IndexODim = - b.create(location, - /*value=*/SumPoolTypeDimIndex[i]); + Value IndexODim = linalg::IndexOp::create(b, location, + /*value=*/SumPoolTypeDimIndex[i]); Value ODim = castIndexToInt64(b, location, IndexODim); Value DDim = b.createOrFold( location, b.getI64IntegerAttr(strideInts[i])); @@ -1138,8 +1132,8 @@ LogicalResult ConvertAtenAvgPoolOp::matchAndRewrite( return rewriter.notifyMatchFailure(op, "unable to compute sumpool"); // Compute the average of sumPool. - Value outputTensor = rewriter.create( - loc, getAsOpFoldResult(outTensorShape), resultElementType); + Value outputTensor = tensor::EmptyOp::create( + rewriter, loc, getAsOpFoldResult(outTensorShape), resultElementType); SmallVector indexingMapsAvg( 2, rewriter.getMultiDimIdentityMap(Dim + 2)); SmallVector iteratorTypesAvg( @@ -1233,25 +1227,24 @@ LogicalResult ConvertAtenAvgPoolOp:: } Value avgPool = - rewriter - .create( - loc, outputTensor.getType(), sumPool, outputTensor, - /*indexingMaps=*/indexingMapsAvg, - /*iteratorTypes=*/iteratorTypesAvg, - [&](OpBuilder &b, Location loc, ValueRange args) { - if (!poolSize) { - poolSize = poolSizeCalculator.getPoolSize( - b, kernelDimSizes, strideInts, paddingInts); - } - Value divisor = - convertScalarToDtype(b, loc, poolSize, resultElementType); - Value avg; - if (isa(resultElementType)) - avg = b.createOrFold(loc, args[0], divisor); - else if (isa(resultElementType)) - avg = b.createOrFold(loc, args[0], divisor); - b.createOrFold(loc, avg); - }) + 
linalg::GenericOp::create( + rewriter, loc, outputTensor.getType(), sumPool, outputTensor, + /*indexingMaps=*/indexingMapsAvg, + /*iteratorTypes=*/iteratorTypesAvg, + [&](OpBuilder &b, Location loc, ValueRange args) { + if (!poolSize) { + poolSize = poolSizeCalculator.getPoolSize( + b, kernelDimSizes, strideInts, paddingInts); + } + Value divisor = + convertScalarToDtype(b, loc, poolSize, resultElementType); + Value avg; + if (isa(resultElementType)) + avg = b.createOrFold(loc, args[0], divisor); + else if (isa(resultElementType)) + avg = b.createOrFold(loc, args[0], divisor); + b.createOrFold(loc, avg); + }) .getResult(0); rewriter.replaceOpWithNewOp(op, resultType, avgPool); return success(); @@ -1284,19 +1277,18 @@ LogicalResult ConvertAtenAvgPoolOp:: divisor = convertScalarToDtype(rewriter, loc, divisor, resultElementType); Value avgPool = - rewriter - .create( - loc, outputTensor.getType(), sumPool, outputTensor, - /*indexingMaps=*/indexingMapsAvg, - /*iteratorTypes=*/iteratorTypesAvg, - [&](OpBuilder &b, Location loc, ValueRange args) { - Value avg; - if (isa(resultElementType)) - avg = b.create(loc, args[0], divisor); - else if (isa(resultElementType)) - avg = b.create(loc, args[0], divisor); - b.create(loc, avg); - }) + linalg::GenericOp::create( + rewriter, loc, outputTensor.getType(), sumPool, outputTensor, + /*indexingMaps=*/indexingMapsAvg, + /*iteratorTypes=*/iteratorTypesAvg, + [&](OpBuilder &b, Location loc, ValueRange args) { + Value avg; + if (isa(resultElementType)) + avg = arith::DivSIOp::create(b, loc, args[0], divisor); + else if (isa(resultElementType)) + avg = arith::DivFOp::create(b, loc, args[0], divisor); + linalg::YieldOp::create(b, loc, avg); + }) .getResult(0); rewriter.replaceOpWithNewOp(op, resultType, avgPool); return success(); @@ -1382,10 +1374,10 @@ class AdaptiveMaxPoolingHelper : public AdaptivePoolingHelper { elementType, APFloat::getInf(cast(elementType).getFloatSemantics(), /*Negative=*/true)); - buffVal = 
rewriter.create(loc, elementType, - smallestFPValueAttr); - auxTensor = rewriter.create( - loc, getAsOpFoldResult(outputSizes), auxTensorElementType); + buffVal = arith::ConstantOp::create(rewriter, loc, elementType, + smallestFPValueAttr); + auxTensor = tensor::EmptyOp::create( + rewriter, loc, getAsOpFoldResult(outputSizes), auxTensorElementType); for (unsigned i = 0; i < rank; i++) { auxTensorExprs.push_back(rewriter.getAffineDimExpr(i)); } @@ -1400,8 +1392,8 @@ class AdaptiveMaxPoolingHelper : public AdaptivePoolingHelper { Value &out2, Value &auxOut) { // compute max using select, since cond1 will be used for indices Value cond1 = - b.create(loc, arith::CmpFPredicate::OGT, inElt, res); - out2 = b.create(loc, cond1, inElt, res); + arith::CmpFOp::create(b, loc, arith::CmpFPredicate::OGT, inElt, res); + out2 = arith::SelectOp::create(b, loc, cond1, inElt, res); // index in different dims (n x c x d x h x w) // 1d: (iw) // 2d: (ih*W + iw) @@ -1409,12 +1401,12 @@ class AdaptiveMaxPoolingHelper : public AdaptivePoolingHelper { Value currIndex = inputElementIndices[nonSpatial]; for (unsigned i = 0; i < rank - nonSpatial - 1; i++) { Value prevTimesNewSize = - b.create(loc, currIndex, inputSpatialSizes[i + 1]); - currIndex = b.create( - loc, prevTimesNewSize, inputElementIndices[nonSpatial + i + 1]); + arith::MulIOp::create(b, loc, currIndex, inputSpatialSizes[i + 1]); + currIndex = arith::AddIOp::create( + b, loc, prevTimesNewSize, inputElementIndices[nonSpatial + i + 1]); } Value indexOut1Int = castIndexToInt64(b, loc, currIndex); - auxOut = b.create(loc, cond1, indexOut1Int, maxIndex); + auxOut = arith::SelectOp::create(b, loc, cond1, indexOut1Int, maxIndex); return success(); } @@ -1427,9 +1419,9 @@ class AdaptiveMaxPoolingHelper : public AdaptivePoolingHelper { const SmallVector &outputExprs) { Location loc = op->getLoc(); Value maxValues = - rewriter.create(loc, outputType, adaptivePoolOutput); + tensor::CastOp::create(rewriter, loc, outputType, 
adaptivePoolOutput); Value outputIndices = - rewriter.create(loc, auxTensorType, auxTensorReturn); + tensor::CastOp::create(rewriter, loc, auxTensorType, auxTensorReturn); rewriter.replaceOp(op, {maxValues, outputIndices}); return success(); } @@ -1457,10 +1449,10 @@ class AdaptiveAvgPoolingHelper : public AdaptivePoolingHelper { const TypeConverter *typeConverter = opConversionPattern.getTypeConverter(); outputType = cast( typeConverter->convertType(op.getResult().getType())); - buffVal = rewriter.create( - loc, elementType, rewriter.getFloatAttr(elementType, 0)); - auxTensor = rewriter.create( - loc, getAsOpFoldResult(outShapeIndexVector), elementType); + buffVal = arith::ConstantOp::create(rewriter, loc, elementType, + rewriter.getFloatAttr(elementType, 0)); + auxTensor = tensor::EmptyOp::create( + rewriter, loc, getAsOpFoldResult(outShapeIndexVector), elementType); for (unsigned i = nonSpatial; i < rank; i++) { auxTensorExprs.push_back(rewriter.getAffineDimExpr(i)); } @@ -1473,14 +1465,14 @@ class AdaptiveAvgPoolingHelper : public AdaptivePoolingHelper { const SmallVector &inputSpatialSizes, const Value &indexOne, const SmallVector &starts, const SmallVector &ends, Value &out2, Value &auxOut) { - out2 = b.create(loc, inElt, res); + out2 = arith::AddFOp::create(b, loc, inElt, res); Value kernelVolume = indexOne; for (unsigned i = 0; i < rank - nonSpatial; i++) { - Value currSize = b.create(loc, ends[i], starts[i]); - kernelVolume = b.create(loc, kernelVolume, currSize); + Value currSize = arith::SubIOp::create(b, loc, ends[i], starts[i]); + kernelVolume = arith::MulIOp::create(b, loc, kernelVolume, currSize); } Value auxOutSI = castIndexToInt64(b, loc, kernelVolume); - auxOut = b.create(loc, elementType, auxOutSI); + auxOut = arith::SIToFPOp::create(b, loc, elementType, auxOutSI); return success(); } @@ -1497,15 +1489,15 @@ class AdaptiveAvgPoolingHelper : public AdaptivePoolingHelper { {auxTensorExprs, outputExprs}, op.getContext()); SmallVector iteratorTypes1( 
rank, utils::IteratorType::parallel); - auto output = rewriter.create( - loc, /*resultTensorTypes=*/adaptivePoolOutput.getType(), + auto output = linalg::GenericOp::create( + rewriter, loc, /*resultTensorTypes=*/adaptivePoolOutput.getType(), /*inputs=*/auxTensorReturn, /*outputs=*/adaptivePoolOutput, /*indexingMaps=*/indexingMaps1, /*iteratorTypes=*/iteratorTypes1, [&](OpBuilder &b, Location loc, ValueRange args) { - Value q = b.create(loc, args[1], args[0]); - b.create(loc, q); + Value q = arith::DivFOp::create(b, loc, args[1], args[0]); + linalg::YieldOp::create(b, loc, q); }); rewriter.replaceOpWithNewOp(op, outputType, @@ -1608,18 +1600,18 @@ class ConvertAtenAdaptivePoolOp : public OpConversionPattern { Type boolType = rewriter.getI1Type(); SmallVector kIterSizeVector; Value constantOne = - rewriter.create(loc, rewriter.getIndexAttr(1)); + arith::ConstantOp::create(rewriter, loc, rewriter.getIndexAttr(1)); for (int i = 0; i < rank - nonSpatial; i++) { - Value hInPlusOne = rewriter.create( - loc, inputSpatialSizes[i], constantOne); - Value kMaxMinusOne = rewriter.create( - loc, hInPlusOne, outShapeIndexVector[i]); + Value hInPlusOne = arith::SubIOp::create( + rewriter, loc, inputSpatialSizes[i], constantOne); + Value kMaxMinusOne = arith::CeilDivSIOp::create(rewriter, loc, hInPlusOne, + outShapeIndexVector[i]); Value kMax = - rewriter.create(loc, constantOne, kMaxMinusOne); + arith::AddIOp::create(rewriter, loc, constantOne, kMaxMinusOne); kIterSizeVector.push_back(kMax); } - Value kIter = rewriter.create( - loc, getAsOpFoldResult(kIterSizeVector), boolType); + Value kIter = tensor::EmptyOp::create( + rewriter, loc, getAsOpFoldResult(kIterSizeVector), boolType); // get output sizes used for initializing some tensors SmallVector outputSizes; @@ -1676,12 +1668,12 @@ class ConvertAtenAdaptivePoolOp : public OpConversionPattern { for (unsigned i = 0; i < rank - nonSpatial; i++) { iteratorTypes.push_back(utils::IteratorType::reduction); } - Value indexOne = 
rewriter.create(loc, 1); + Value indexOne = arith::ConstantIndexOp::create(rewriter, loc, 1); bool failedCustomization = false; // adaptive pooling generic op - auto adaptivePool = rewriter.create( - loc, /*resultTensorTypes=*/ + auto adaptivePool = linalg::GenericOp::create( + rewriter, loc, /*resultTensorTypes=*/ TypeRange({initOutput.getType(), auxTensor.getType()}), /*inputs=*/ValueRange({kIter}), /*outputs=*/ValueRange({initOutput, auxTensor}), @@ -1692,26 +1684,26 @@ class ConvertAtenAdaptivePoolOp : public OpConversionPattern { Value maxIndex = args[2]; SmallVector ind; for (unsigned i = 0; i < 2 * rank - nonSpatial; i++) { - ind.push_back(b.create(loc, i)); + ind.push_back(linalg::IndexOp::create(b, loc, i)); } // compute start and end indices // st = s1( s0(ind2 * Hin) // Hout ) SmallVector starts; SmallVector ends; for (unsigned i = nonSpatial; i < rank; i++) { - Value s0 = b.create( - loc, ind[i], inputSpatialSizes[i - nonSpatial]); - Value s1 = b.create( - loc, s0, outShapeIndexVector[i - nonSpatial]); + Value s0 = arith::MulIOp::create(b, loc, ind[i], + inputSpatialSizes[i - nonSpatial]); + Value s1 = arith::FloorDivSIOp::create( + b, loc, s0, outShapeIndexVector[i - nonSpatial]); starts.push_back(s1); // en = e4( 1 + e3( e2( e1( e0(ind2 + 1) * hIn ) - 1 ) // hOut ) ) - Value e0 = b.create(loc, ind[i], indexOne); - Value e1 = b.create( - loc, e0, inputSpatialSizes[i - nonSpatial]); - Value e2 = b.create(loc, e1, indexOne); - Value e3 = b.create( - loc, e2, outShapeIndexVector[i - nonSpatial]); - Value e4 = b.create(loc, indexOne, e3); + Value e0 = arith::AddIOp::create(b, loc, ind[i], indexOne); + Value e1 = arith::MulIOp::create(b, loc, e0, + inputSpatialSizes[i - nonSpatial]); + Value e2 = arith::SubIOp::create(b, loc, e1, indexOne); + Value e3 = arith::FloorDivSIOp::create( + b, loc, e2, outShapeIndexVector[i - nonSpatial]); + Value e4 = arith::AddIOp::create(b, loc, indexOne, e3); ends.push_back(e4); } // extract input element @@ -1720,18 +1712,18 
@@ class ConvertAtenAdaptivePoolOp : public OpConversionPattern { inputElementIndices.push_back(ind[i]); } for (unsigned i = nonSpatial; i < rank; i++) { - inputElementIndices.push_back(b.create( - loc, starts[i - nonSpatial], ind[rank - nonSpatial + i])); + inputElementIndices.push_back(arith::AddIOp::create( + b, loc, starts[i - nonSpatial], ind[rank - nonSpatial + i])); } - Value inElt = b.create(loc, elementType, buffInput, - inputElementIndices); + Value inElt = tensor::ExtractOp::create( + b, loc, elementType, buffInput, inputElementIndices); // check if we extracted at windex < end index for (unsigned i = 0; i < rank - nonSpatial; i++) { - Value cond = b.create( - loc, arith::CmpIPredicate(6), + Value cond = arith::CmpIOp::create( + b, loc, arith::CmpIPredicate(6), inputElementIndices[i + nonSpatial], ends[i]); // if out-of-bounds, replace the extracted element with buffVal - inElt = b.create(loc, cond, inElt, buffVal); + inElt = arith::SelectOp::create(b, loc, cond, inElt, buffVal); } Value out2, auxOut; // customize for max vs. 
avg: @@ -1740,7 +1732,7 @@ class ConvertAtenAdaptivePoolOp : public OpConversionPattern { inputSpatialSizes, indexOne, starts, ends, out2, auxOut))) { failedCustomization = true; } - b.create(loc, ValueRange({out2, auxOut})); + linalg::YieldOp::create(b, loc, ValueRange({out2, auxOut})); }); if (failedCustomization) { diff --git a/lib/Conversion/TorchToLinalg/Random.cpp b/lib/Conversion/TorchToLinalg/Random.cpp index 854e3f86d367..f33b7436041d 100644 --- a/lib/Conversion/TorchToLinalg/Random.cpp +++ b/lib/Conversion/TorchToLinalg/Random.cpp @@ -58,14 +58,14 @@ static Value toLinearIndex(OpBuilder &b, Location loc, assert(indicesIntValues.size() == shapeIntValues.size() && "Expected `indices` and `shape` to have the same size"); Value result = - b.create(loc, b.getZeroAttr(b.getI64Type())); + arith::ConstantOp::create(b, loc, b.getZeroAttr(b.getI64Type())); for (auto [index, stride] : llvm::zip(indicesIntValues, shapeIntValues)) { assert(isa(index.getType()) && isa(stride.getType()) && "Input arrays to `toLinearIndex` must only contain values of type " "`mlir::IntegerType`"); - Value mul = b.create(loc, result, stride); - result = b.create(loc, mul, index); + Value mul = arith::MulIOp::create(b, loc, result, stride); + result = arith::AddIOp::create(b, loc, mul, index); } return result; } @@ -75,22 +75,22 @@ static Value toLinearIndex(OpBuilder &b, Location loc, static Value randomUniformUInt(OpBuilder &b, Location loc, Value ctr, Value key) { auto mul = [&](Value lhs, Value rhs) -> Value { - return b.create(loc, lhs, rhs); + return arith::MulIOp::create(b, loc, lhs, rhs); }; auto add = [&](Value lhs, Value rhs) -> Value { - return b.create(loc, lhs, rhs); + return arith::AddIOp::create(b, loc, lhs, rhs); }; - Value cst32 = b.create(loc, b.getI64IntegerAttr(32)); + Value cst32 = arith::ConstantOp::create(b, loc, b.getI64IntegerAttr(32)); auto shiftRight32 = [&](Value val) -> Value { - return b.create(loc, val, cst32); + return arith::ShRUIOp::create(b, loc, val, 
cst32); }; auto swapLoHi = [&](Value val) -> Value { - Value leftShift = b.create(loc, val, cst32); + Value leftShift = arith::ShLIOp::create(b, loc, val, cst32); Value rightShift = shiftRight32(val); - return b.create(loc, leftShift, rightShift); + return arith::OrIOp::create(b, loc, leftShift, rightShift); }; auto bitwiseXOr = [&](Value lhs, Value rhs) -> Value { - return b.create(loc, lhs, rhs); + return arith::XOrIOp::create(b, loc, lhs, rhs); }; Value t, x, y, z; @@ -115,14 +115,15 @@ static Value randomUniformF64(OpBuilder &b, Location loc, Value ctr, Value key, // scale = (max - min) * const(F64, 5.4210108E-20) // which is derived from rand(min,max) = // rand()/(RAND_MAX/(max-min)) where RAND_MAX = 2^64 - 1 - Value epsilon = b.create( - loc, b.getFloatAttr(b.getF64Type(), 5.4210108E-20)); - Value range = b.create(loc, max, min); - Value scale = b.create(loc, range, epsilon); + Value epsilon = arith::ConstantOp::create( + b, loc, b.getFloatAttr(b.getF64Type(), 5.4210108E-20)); + Value range = arith::SubFOp::create(b, loc, max, min); + Value scale = arith::MulFOp::create(b, loc, range, epsilon); // res = cast(F64, tempN) * scale + min - Value updateFloat = b.create(loc, b.getF64Type(), randomVal); - Value updateScaled = b.create(loc, updateFloat, scale); - Value uniformSample = b.create(loc, updateScaled, min); + Value updateFloat = + arith::UIToFPOp::create(b, loc, b.getF64Type(), randomVal); + Value updateScaled = arith::MulFOp::create(b, loc, updateFloat, scale); + Value uniformSample = arith::AddFOp::create(b, loc, updateScaled, min); return uniformSample; } @@ -153,7 +154,7 @@ class ConvertAtenUniformOp : public OpConversionPattern { op, "The generator has to be None because only global default " "generator is supported"); // Get key, min and max used by `linalg.generic` compute payload. 
- Value key = rewriter.create(loc); + Value key = TorchConversion::GetNextSeedOp::create(rewriter, loc); Value min = convertScalarToDtype(rewriter, loc, from, f64Ty); Value max = convertScalarToDtype(rewriter, loc, to, f64Ty); @@ -166,30 +167,28 @@ class ConvertAtenUniformOp : public OpConversionPattern { SmallVector sizes = getTensorSizes(rewriter, loc, self); SmallVector sizesIntValues = castIndexVectorToInt64Vector(rewriter, loc, sizes); - Value initTensor = - rewriter.create(loc, getAsOpFoldResult(sizes), elemTy); + Value initTensor = tensor::EmptyOp::create( + rewriter, loc, getAsOpFoldResult(sizes), elemTy); Value uniformRes = - rewriter - .create( - loc, initTensor.getType(), /*inputs=*/ValueRange{}, - /*outputs=*/initTensor, indexingMaps, iteratorTypes, - [&](OpBuilder &b, Location loc, ValueRange args) { - SmallVector indicesIntValues; - for (int i = 0; i < resultRank; i++) { - indicesIntValues.push_back(castIndexToInt64( - b, loc, b.create(loc, i))); - } - - Value linearIndex = - toLinearIndex(b, loc, indicesIntValues, sizesIntValues); - - Value res = - randomUniformF64(b, loc, linearIndex, key, min, max); - Value truncRes = res; - if (isa(elemTy)) - truncRes = b.create(loc, elemTy, res); - b.create(loc, truncRes); - }) + linalg::GenericOp::create( + rewriter, loc, initTensor.getType(), /*inputs=*/ValueRange{}, + /*outputs=*/initTensor, indexingMaps, iteratorTypes, + [&](OpBuilder &b, Location loc, ValueRange args) { + SmallVector indicesIntValues; + for (int i = 0; i < resultRank; i++) { + indicesIntValues.push_back(castIndexToInt64( + b, loc, linalg::IndexOp::create(b, loc, i))); + } + + Value linearIndex = + toLinearIndex(b, loc, indicesIntValues, sizesIntValues); + + Value res = randomUniformF64(b, loc, linearIndex, key, min, max); + Value truncRes = res; + if (isa(elemTy)) + truncRes = arith::TruncFOp::create(b, loc, elemTy, res); + linalg::YieldOp::create(b, loc, truncRes); + }) .getResult(0); Type newResultType = 
getTypeConverter()->convertType(op.getType()); @@ -246,14 +245,14 @@ class ConvertAtenMultinomialOp : public OpConversionPattern { op, "torch.multinomial accepts only rank 1 or 2 tensors as weights"); } - Value cstZero = rewriter.create( - loc, i64Ty, rewriter.getI64IntegerAttr(0)); - Value cstOne = rewriter.create( - loc, i64Ty, rewriter.getI64IntegerAttr(1)); - Value zeroIndex = rewriter.create(loc, 0); - Value oneIndex = rewriter.create(loc, 1); + Value cstZero = arith::ConstantOp::create(rewriter, loc, i64Ty, + rewriter.getI64IntegerAttr(0)); + Value cstOne = arith::ConstantOp::create(rewriter, loc, i64Ty, + rewriter.getI64IntegerAttr(1)); + Value zeroIndex = arith::ConstantIndexOp::create(rewriter, loc, 0); + Value oneIndex = arith::ConstantIndexOp::create(rewriter, loc, 1); Value numSamplesIndex = - rewriter.create(loc, indexTy, numSamples); + arith::IndexCastOp::create(rewriter, loc, indexTy, numSamples); Value numDistributions; Value numCategoriesIndex; @@ -261,22 +260,22 @@ class ConvertAtenMultinomialOp : public OpConversionPattern { if (inputRank == 1) { numDistributions = cstOne; numCategoriesIndex = - rewriter.create(loc, indexTy, self, zeroIndex); + tensor::DimOp::create(rewriter, loc, indexTy, self, zeroIndex); resultShape = ValueRange{numSamplesIndex}; } else { Value numDistIndex = - rewriter.create(loc, indexTy, self, zeroIndex); + tensor::DimOp::create(rewriter, loc, indexTy, self, zeroIndex); numCategoriesIndex = - rewriter.create(loc, indexTy, self, oneIndex); + tensor::DimOp::create(rewriter, loc, indexTy, self, oneIndex); numDistributions = - rewriter.create(loc, i64Ty, numDistIndex); + arith::IndexCastOp::create(rewriter, loc, i64Ty, numDistIndex); resultShape = ValueRange{numDistIndex, numSamplesIndex}; } Value numCategories = - rewriter.create(loc, i64Ty, numCategoriesIndex); - Value resultTensor = rewriter.create( - loc, getAsOpFoldResult(resultShape), i64Ty); + arith::IndexCastOp::create(rewriter, loc, i64Ty, numCategoriesIndex); + Value 
resultTensor = tensor::EmptyOp::create( + rewriter, loc, getAsOpFoldResult(resultShape), i64Ty); // sum weights for normalization torch_to_linalg::ReductionOpInfo opInfo; @@ -285,8 +284,8 @@ class ConvertAtenMultinomialOp : public OpConversionPattern { else opInfo = {false, self, {1}}; - Value initSum = rewriter.create( - loc, f64Ty, rewriter.getF64FloatAttr(0.0)); + Value initSum = arith::ConstantOp::create(rewriter, loc, f64Ty, + rewriter.getF64FloatAttr(0.0)); int64_t srcWidth = cast(elemTy).getWidth(); if (srcWidth > 64) op->emitWarning("Op bitwidth will be truncated from " + @@ -294,12 +293,12 @@ class ConvertAtenMultinomialOp : public OpConversionPattern { auto sumBody = [&](OpBuilder &b, Location loc, ValueRange payloadArgs) { Value input = payloadArgs[0]; if (srcWidth < 64) - input = b.create(loc, f64Ty, input); + input = arith::ExtFOp::create(b, loc, f64Ty, input); if (srcWidth > 64) - input = b.create(loc, f64Ty, input); + input = arith::TruncFOp::create(b, loc, f64Ty, input); Value result = payloadArgs[1]; - Value nextSum = b.create(loc, input, result); - b.create(loc, nextSum); + Value nextSum = arith::AddFOp::create(b, loc, input, result); + linalg::YieldOp::create(b, loc, nextSum); }; Value sumWeights = torch_to_linalg::createReductionLinalgGeneric( rewriter, loc, opInfo, initSum, sumBody); @@ -307,66 +306,66 @@ class ConvertAtenMultinomialOp : public OpConversionPattern { // Get multinomial samples for each weight vector auto multinomialComputation = [&](OpBuilder &b, Location loc, Value j, ValueRange args) { - Value jIndex = b.create(loc, indexTy, j); + Value jIndex = arith::IndexCastOp::create(b, loc, indexTy, j); Value sum; if (inputRank == 1) { - sum = b.create(loc, sumWeights, ValueRange{}); + sum = tensor::ExtractOp::create(b, loc, sumWeights, ValueRange{}); } else { - sum = b.create(loc, sumWeights, ValueRange{jIndex}); + sum = tensor::ExtractOp::create(b, loc, sumWeights, ValueRange{jIndex}); } // compute cdf in loop - Value initCdf = 
b.create( - loc, getAsOpFoldResult(ValueRange{numCategoriesIndex}), f64Ty); + Value initCdf = tensor::EmptyOp::create( + b, loc, getAsOpFoldResult(ValueRange{numCategoriesIndex}), f64Ty); Value cdf = - b.create( - loc, cstZero, numCategories, cstOne, ValueRange{initCdf}, - [&](OpBuilder &b, Location loc, Value i, ValueRange vals) { - Value distribution = vals[0]; - // if (i > 0) - auto comparisonPredicate = arith::CmpIPredicateAttr::get( - b.getContext(), arith::CmpIPredicate::sgt); - Value condition = b.create( - loc, comparisonPredicate, i, cstZero); - Value iIndex = b.create(loc, indexTy, i); - // curr_cum = i > 0 ? prob[i] + prob[i-1] : prob[i] - ValueRange ind; - if (inputRank == 1) { - ind = ValueRange{iIndex}; - } else { - ind = ValueRange{jIndex, iIndex}; - } - Value currWeight = b.create(loc, self, ind); - if (srcWidth < 64) - currWeight = b.create(loc, f64Ty, currWeight); - if (srcWidth > 64) - currWeight = - b.create(loc, f64Ty, currWeight); - Value currMass = b.create(loc, currWeight, sum); - Value currCum = - b.create( - loc, condition, - [&](OpBuilder &b, Location loc) { - Value prevI = - b.create(loc, i, cstOne); - Value prevIndex = b.create( - loc, indexTy, prevI); - Value prevMass = b.create( - loc, distribution, ValueRange{prevIndex}); - Value currSum = b.create( - loc, currMass, prevMass); - b.create(loc, ValueRange(currSum)); - }, - [&](OpBuilder &b, Location loc) { - b.create(loc, ValueRange{currMass}); - }) - .getResult(0); - - Value updatedCdf = b.create( - loc, currCum, distribution, ValueRange(iIndex)); - b.create(loc, ValueRange(updatedCdf)); - }) + scf::ForOp::create( + b, loc, cstZero, numCategories, cstOne, ValueRange{initCdf}, + [&](OpBuilder &b, Location loc, Value i, ValueRange vals) { + Value distribution = vals[0]; + // if (i > 0) + auto comparisonPredicate = arith::CmpIPredicateAttr::get( + b.getContext(), arith::CmpIPredicate::sgt); + Value condition = arith::CmpIOp::create( + b, loc, comparisonPredicate, i, cstZero); + Value 
iIndex = arith::IndexCastOp::create(b, loc, indexTy, i); + // curr_cum = i > 0 ? prob[i] + prob[i-1] : prob[i] + ValueRange ind; + if (inputRank == 1) { + ind = ValueRange{iIndex}; + } else { + ind = ValueRange{jIndex, iIndex}; + } + Value currWeight = tensor::ExtractOp::create(b, loc, self, ind); + if (srcWidth < 64) + currWeight = arith::ExtFOp::create(b, loc, f64Ty, currWeight); + if (srcWidth > 64) + currWeight = + arith::TruncFOp::create(b, loc, f64Ty, currWeight); + Value currMass = arith::DivFOp::create(b, loc, currWeight, sum); + Value currCum = + scf::IfOp::create( + b, loc, condition, + [&](OpBuilder &b, Location loc) { + Value prevI = + arith::SubIOp::create(b, loc, i, cstOne); + Value prevIndex = arith::IndexCastOp::create( + b, loc, indexTy, prevI); + Value prevMass = tensor::ExtractOp::create( + b, loc, distribution, ValueRange{prevIndex}); + Value currSum = + arith::AddFOp::create(b, loc, currMass, prevMass); + scf::YieldOp::create(b, loc, ValueRange(currSum)); + }, + [&](OpBuilder &b, Location loc) { + scf::YieldOp::create(b, loc, ValueRange{currMass}); + }) + .getResult(0); + + Value updatedCdf = tensor::InsertOp::create( + b, loc, currCum, distribution, ValueRange(iIndex)); + scf::YieldOp::create(b, loc, ValueRange(updatedCdf)); + }) .getResult(0); /* @@ -383,128 +382,120 @@ class ConvertAtenMultinomialOp : public OpConversionPattern { * */ // Get key, min and max used by RNG. 
- Value key = b.create(loc); - Value min = b.create(loc, f64Ty, - rewriter.getF64FloatAttr(0.0)); - Value max = b.create(loc, f64Ty, - rewriter.getF64FloatAttr(1.0)); + Value key = TorchConversion::GetNextSeedOp::create(b, loc); + Value min = arith::ConstantOp::create(b, loc, f64Ty, + rewriter.getF64FloatAttr(0.0)); + Value max = arith::ConstantOp::create(b, loc, f64Ty, + rewriter.getF64FloatAttr(1.0)); // iterate and sample class indices Value result = args[0]; Value finalResult = - rewriter - .create( - loc, cstZero, numSamples, cstOne, ValueRange{result}, - [&](OpBuilder &b, Location loc, Value i, ValueRange args) { - // Sample random float - Value uniformSample = - randomUniformF64(b, loc, i, key, min, max); - - // binary search in cdf to find our sample - Value left = b.create( - loc, i64Ty, b.getI64IntegerAttr(0)); - Value right = numCategories; - - auto checkCondition = [&](OpBuilder &b, Location loc, - ValueRange vals) { - Value left = vals[0]; - Value right = vals[1]; - - // while (right > left) - auto comparisonPredicate = arith::CmpIPredicateAttr::get( - b.getContext(), arith::CmpIPredicate::sgt); - Value loopCondition = b.create( - loc, comparisonPredicate, right, left); - b.create(loc, loopCondition, vals); - }; - - ValueRange whileResults = - b.create( - loc, TypeRange{i64Ty, i64Ty}, - ValueRange{left, right}, checkCondition, - [&](OpBuilder &b, Location loc, ValueRange vals) { - Value left = vals[0]; - Value right = vals[1]; - - Value two = b.create( - loc, i64Ty, b.getI64IntegerAttr(2)); - Value diff = - b.create(loc, right, left); - Value diffMid = - b.create(loc, diff, two); - Value midPointer = - b.create(loc, left, diffMid); - Type indexTy = b.getIndexType(); - Value midIndex = b.create( - loc, indexTy, midPointer); - - // branch and update search indices - auto thenBlock = [&](OpBuilder &b, - Location loc) { - // left = mid + 1 - Value newLeft = b.create( - loc, midPointer, cstOne); - - b.create( - loc, ValueRange{newLeft, right}); - }; - auto 
elseBlock = [&](OpBuilder &b, - Location loc) { - // right = mid - b.create( - loc, ValueRange{left, midPointer}); - }; - - Value cumProb = b.create( - loc, cdf, ValueRange{midIndex}); - auto cmpPredicate = - arith::CmpFPredicateAttr::get( - b.getContext(), - arith::CmpFPredicate::OLT); - Value branchCondition = b.create( - loc, cmpPredicate, cumProb, uniformSample); - ValueRange branchResults = - b.create(loc, branchCondition, - thenBlock, elseBlock) - .getResults(); - Value newLeft = branchResults[0]; - Value newRight = branchResults[1]; - - b.create( - loc, ValueRange{newLeft, newRight}); - }) - .getResults(); - - // sample_idx = left_pointer - Value samplePointer = whileResults[0]; - Value iIndex = - b.create(loc, indexTy, i); - - Value prevResult = args[0]; - Value newResult; - if (inputRank == 1) { - // result[i] = sample_idx - newResult = b.create( - loc, samplePointer, prevResult, ValueRange{iIndex}); - } else { - // result[j][i] = sample_idx - newResult = b.create( - loc, samplePointer, prevResult, - ValueRange{jIndex, iIndex}); - } - - b.create(loc, ValueRange{newResult}); - }) + scf::ForOp::create( + rewriter, loc, cstZero, numSamples, cstOne, ValueRange{result}, + [&](OpBuilder &b, Location loc, Value i, ValueRange args) { + // Sample random float + Value uniformSample = + randomUniformF64(b, loc, i, key, min, max); + + // binary search in cdf to find our sample + Value left = arith::ConstantOp::create(b, loc, i64Ty, + b.getI64IntegerAttr(0)); + Value right = numCategories; + + auto checkCondition = [&](OpBuilder &b, Location loc, + ValueRange vals) { + Value left = vals[0]; + Value right = vals[1]; + + // while (right > left) + auto comparisonPredicate = arith::CmpIPredicateAttr::get( + b.getContext(), arith::CmpIPredicate::sgt); + Value loopCondition = arith::CmpIOp::create( + b, loc, comparisonPredicate, right, left); + scf::ConditionOp::create(b, loc, loopCondition, vals); + }; + + ValueRange whileResults = + scf::WhileOp::create( + b, loc, 
TypeRange{i64Ty, i64Ty}, + ValueRange{left, right}, checkCondition, + [&](OpBuilder &b, Location loc, ValueRange vals) { + Value left = vals[0]; + Value right = vals[1]; + + Value two = arith::ConstantOp::create( + b, loc, i64Ty, b.getI64IntegerAttr(2)); + Value diff = + arith::SubIOp::create(b, loc, right, left); + Value diffMid = + arith::DivSIOp::create(b, loc, diff, two); + Value midPointer = + arith::AddIOp::create(b, loc, left, diffMid); + Type indexTy = b.getIndexType(); + Value midIndex = arith::IndexCastOp::create( + b, loc, indexTy, midPointer); + + // branch and update search indices + auto thenBlock = [&](OpBuilder &b, Location loc) { + // left = mid + 1 + Value newLeft = arith::AddIOp::create( + b, loc, midPointer, cstOne); + + scf::YieldOp::create(b, loc, + ValueRange{newLeft, right}); + }; + auto elseBlock = [&](OpBuilder &b, Location loc) { + // right = mid + scf::YieldOp::create(b, loc, + ValueRange{left, midPointer}); + }; + + Value cumProb = tensor::ExtractOp::create( + b, loc, cdf, ValueRange{midIndex}); + auto cmpPredicate = arith::CmpFPredicateAttr::get( + b.getContext(), arith::CmpFPredicate::OLT); + Value branchCondition = arith::CmpFOp::create( + b, loc, cmpPredicate, cumProb, uniformSample); + ValueRange branchResults = + scf::IfOp::create(b, loc, branchCondition, + thenBlock, elseBlock) + .getResults(); + Value newLeft = branchResults[0]; + Value newRight = branchResults[1]; + + scf::YieldOp::create(b, loc, + ValueRange{newLeft, newRight}); + }) + .getResults(); + + // sample_idx = left_pointer + Value samplePointer = whileResults[0]; + Value iIndex = arith::IndexCastOp::create(b, loc, indexTy, i); + + Value prevResult = args[0]; + Value newResult; + if (inputRank == 1) { + // result[i] = sample_idx + newResult = tensor::InsertOp::create( + b, loc, samplePointer, prevResult, ValueRange{iIndex}); + } else { + // result[j][i] = sample_idx + newResult = tensor::InsertOp::create( + b, loc, samplePointer, prevResult, + ValueRange{jIndex, 
iIndex}); + } + + scf::YieldOp::create(b, loc, ValueRange{newResult}); + }) .getResult(0); - b.create(loc, ValueRange{finalResult}); + scf::YieldOp::create(b, loc, ValueRange{finalResult}); }; Value finalResultTensor = - rewriter - .create(loc, cstZero, numDistributions, cstOne, - ValueRange{resultTensor}, - multinomialComputation) + scf::ForOp::create(rewriter, loc, cstZero, numDistributions, cstOne, + ValueRange{resultTensor}, multinomialComputation) .getResult(0); Type newResultType = getTypeConverter()->convertType(op.getType()); diff --git a/lib/Conversion/TorchToLinalg/Reduction.cpp b/lib/Conversion/TorchToLinalg/Reduction.cpp index e3635fbbd095..2e33b724e4c3 100644 --- a/lib/Conversion/TorchToLinalg/Reduction.cpp +++ b/lib/Conversion/TorchToLinalg/Reduction.cpp @@ -100,7 +100,7 @@ class ConvertAtenMinMaxDimOp : public OpConversionPattern { SmallVector resultShape; for (int64_t i = 0; i < inputType.getRank(); i++) { if (dim != i) { - auto currentDimSize = rewriter.create(loc, input, i); + auto currentDimSize = tensor::DimOp::create(rewriter, loc, input, i); resultShape.push_back(currentDimSize); } } @@ -109,32 +109,34 @@ class ConvertAtenMinMaxDimOp : public OpConversionPattern { createZeroInitTensor(rewriter, loc, resultShape, idxElementType); // Second fill the output buffer for the running max or min. 
- Value initTensorVal = rewriter.create( - loc, getAsOpFoldResult(resultShape), inElementType); + Value initTensorVal = tensor::EmptyOp::create( + rewriter, loc, getAsOpFoldResult(resultShape), inElementType); Value fillValue; if (isa(inElementType)) { - fillValue = rewriter.create( - loc, rewriter.getFloatAttr( - inElementType, - APFloat::getInf( - cast(inElementType).getFloatSemantics(), - /*Negative=*/isMax))); + fillValue = arith::ConstantOp::create( + rewriter, loc, + rewriter.getFloatAttr( + inElementType, + APFloat::getInf( + cast(inElementType).getFloatSemantics(), + /*Negative=*/isMax))); } else if (!isUnsigned) { auto width = cast(inElementType).getWidth(); auto init = isMax ? APSInt::getSignedMinValue(width) : APSInt::getSignedMaxValue(width); - fillValue = rewriter.create( - loc, rewriter.getIntegerAttr(inElementType, init)); + fillValue = arith::ConstantOp::create( + rewriter, loc, rewriter.getIntegerAttr(inElementType, init)); } else if (isUnsigned) { auto width = cast(inElementType).getWidth(); auto init = isMax ? 
APInt::getMinValue(width) : APInt::getMaxValue(width); - fillValue = rewriter.create( - loc, rewriter.getIntegerAttr(inElementType, init)); + fillValue = arith::ConstantOp::create( + rewriter, loc, rewriter.getIntegerAttr(inElementType, init)); } Value filledTensorVal = - rewriter.create(loc, fillValue, initTensorVal).result(); + linalg::FillOp::create(rewriter, loc, fillValue, initTensorVal) + .result(); SmallVector iteratorTypes( inputType.getRank(), utils::IteratorType::parallel); @@ -155,8 +157,8 @@ class ConvertAtenMinMaxDimOp : public OpConversionPattern { auto maps = AffineMap::inferFromExprList({exprs, resultExprs, resultExprs}, rewriter.getContext()); - auto linalgOp = rewriter.create( - loc, + auto linalgOp = linalg::GenericOp::create( + rewriter, loc, ArrayRef({filledTensorVal.getType(), filledTensorIdx.getType()}), input, ValueRange({filledTensorVal, filledTensorIdx}), maps, iteratorTypes, @@ -166,62 +168,62 @@ class ConvertAtenMinMaxDimOp : public OpConversionPattern { Value oldValue = blockArgs[1]; Value oldIndex = blockArgs[2]; - Value newIndex = rewriter.create( - nestedLoc, oldIndex.getType(), - rewriter.create(loc, dim)); + Value newIndex = arith::IndexCastOp::create( + rewriter, nestedLoc, oldIndex.getType(), + linalg::IndexOp::create(rewriter, loc, dim)); Value resultVal, predicate; if (isa(inElementType)) { arith::CmpFPredicate predType; if (isMax) { predType = arith::CmpFPredicate::OGT; - resultVal = rewriter.create( - nestedLoc, newValue, oldValue); + resultVal = arith::MaximumFOp::create(rewriter, nestedLoc, + newValue, oldValue); } else { predType = arith::CmpFPredicate::OLT; - resultVal = rewriter.create( - nestedLoc, newValue, oldValue); + resultVal = arith::MinimumFOp::create(rewriter, nestedLoc, + newValue, oldValue); } - predicate = rewriter.create(nestedLoc, predType, - newValue, oldValue); + predicate = arith::CmpFOp::create(rewriter, nestedLoc, predType, + newValue, oldValue); } else { arith::CmpIPredicate predType; if (isMax) { 
predType = isUnsigned ? arith::CmpIPredicate::ugt : arith::CmpIPredicate::sgt; if (isUnsigned) { - resultVal = rewriter.create(nestedLoc, newValue, - oldValue); + resultVal = arith::MaxUIOp::create(rewriter, nestedLoc, + newValue, oldValue); } else { - resultVal = rewriter.create(nestedLoc, newValue, - oldValue); + resultVal = arith::MaxSIOp::create(rewriter, nestedLoc, + newValue, oldValue); } } else { predType = isUnsigned ? arith::CmpIPredicate::ult : arith::CmpIPredicate::slt; if (isUnsigned) { - resultVal = rewriter.create(nestedLoc, newValue, - oldValue); + resultVal = arith::MinUIOp::create(rewriter, nestedLoc, + newValue, oldValue); } else { - resultVal = rewriter.create(nestedLoc, newValue, - oldValue); + resultVal = arith::MinSIOp::create(rewriter, nestedLoc, + newValue, oldValue); } } - predicate = rewriter.create(nestedLoc, predType, - newValue, oldValue); + predicate = arith::CmpIOp::create(rewriter, nestedLoc, predType, + newValue, oldValue); } - auto resultIndex = rewriter.create( - nestedLoc, predicate, newIndex, oldIndex); - nestedBuilder.create( - nestedLoc, ValueRange({resultVal, resultIndex})); + auto resultIndex = arith::SelectOp::create( + rewriter, nestedLoc, predicate, newIndex, oldIndex); + linalg::YieldOp::create(nestedBuilder, nestedLoc, + ValueRange({resultVal, resultIndex})); }); if (!keepDim) { - Value rVal = rewriter.create(loc, valResultType, - linalgOp.getResult(0)); - Value rIdx = rewriter.create(loc, idxResultType, - linalgOp.getResult(1)); + Value rVal = tensor::CastOp::create(rewriter, loc, valResultType, + linalgOp.getResult(0)); + Value rIdx = tensor::CastOp::create(rewriter, loc, idxResultType, + linalgOp.getResult(1)); llvm::SmallVector res{rVal, rIdx}; rewriter.replaceOp(op, res); return success(); @@ -237,10 +239,10 @@ class ConvertAtenMinMaxDimOp : public OpConversionPattern { valShape.resize(valShape.size() - 1); idxShape.resize(idxShape.size() - 1); - Value rVal = rewriter.create( - loc, valResultType.clone(valShape), 
linalgOp.getResult(0)); - Value rIdx = rewriter.create( - loc, idxResultType.clone(idxShape), linalgOp.getResult(1)); + Value rVal = tensor::CastOp::create( + rewriter, loc, valResultType.clone(valShape), linalgOp.getResult(0)); + Value rIdx = tensor::CastOp::create( + rewriter, loc, idxResultType.clone(idxShape), linalgOp.getResult(1)); SmallVector reassociation(valShape.size()); if (reassociation.size() > 0) { @@ -261,11 +263,11 @@ class ConvertAtenMinMaxDimOp : public OpConversionPattern { valShape[dim] = 1; idxShape[dim] = 1; - Value unsqueezeVal = rewriter.create( - loc, valResultType, rVal, reassociation); + Value unsqueezeVal = tensor::ExpandShapeOp::create( + rewriter, loc, valResultType, rVal, reassociation); - Value unsqueezeIdx = rewriter.create( - loc, idxResultType, rIdx, reassociation); + Value unsqueezeIdx = tensor::ExpandShapeOp::create( + rewriter, loc, idxResultType, rIdx, reassociation); llvm::SmallVector unsqueezes = {unsqueezeVal, unsqueezeIdx}; rewriter.replaceOp(op, unsqueezes); @@ -278,67 +280,73 @@ class ConvertAtenMinMaxDimOp : public OpConversionPattern { static Value createAbsOpForNormOps(OpBuilder &b, Location loc, Value elem, Type resultElementType) { if (isa(elem.getType())) { - return b.create(loc, elem); + return complex::AbsOp::create(b, loc, elem); } Value self = convertScalarToDtype(b, loc, elem, resultElementType); - return b.create(loc, self); + return math::AbsFOp::create(b, loc, self); } static Value createInitElementForReduceOp(OpBuilder &b, Location loc, Operation *op, Type elementType) { if (isa(op)) - return b.create(loc, b.getZeroAttr(elementType)); + return arith::ConstantOp::create(b, loc, b.getZeroAttr(elementType)); if (isa(op)) { if (isa(elementType)) - return b.create(loc, b.getFloatAttr(elementType, 1.0)); + return arith::ConstantOp::create(b, loc, + b.getFloatAttr(elementType, 1.0)); else if (isa(elementType)) - return b.create(loc, b.getIntegerAttr(elementType, 1)); + return arith::ConstantOp::create(b, loc, + 
b.getIntegerAttr(elementType, 1)); } if (isa(op)) { if (isa(elementType)) - return b.create( - loc, b.getFloatAttr( - elementType, - APFloat::getInf( - cast(elementType).getFloatSemantics(), - /*Negative=*/true))); + return arith::ConstantOp::create( + b, loc, + b.getFloatAttr( + elementType, + APFloat::getInf( + cast(elementType).getFloatSemantics(), + /*Negative=*/true))); else if (isa(elementType) && elementType.getIntOrFloatBitWidth() != 8) - return b.create( - loc, b.getIntegerAttr(elementType, - APSInt::getSignedMinValue( - elementType.getIntOrFloatBitWidth()))); + return arith::ConstantOp::create( + b, loc, + b.getIntegerAttr( + elementType, + APSInt::getSignedMinValue(elementType.getIntOrFloatBitWidth()))); } if (isa(op)) { if (isa(elementType)) - return b.create( - loc, b.getFloatAttr( - elementType, - APFloat::getInf( - cast(elementType).getFloatSemantics(), - /*Negative=*/false))); + return arith::ConstantOp::create( + b, loc, + b.getFloatAttr( + elementType, + APFloat::getInf( + cast(elementType).getFloatSemantics(), + /*Negative=*/false))); else if (isa(elementType) && elementType.getIntOrFloatBitWidth() != 8) - return b.create( - loc, b.getIntegerAttr(elementType, - APSInt::getSignedMaxValue( - elementType.getIntOrFloatBitWidth()))); + return arith::ConstantOp::create( + b, loc, + b.getIntegerAttr( + elementType, + APSInt::getSignedMaxValue(elementType.getIntOrFloatBitWidth()))); } if (isa(op) || isa(op) || isa(op)) - return b.create(loc, b.getZeroAttr(elementType)); + return arith::ConstantOp::create(b, loc, b.getZeroAttr(elementType)); if (isa(op)) { - return b.create(loc, b.getBoolAttr(true)); + return arith::ConstantOp::create(b, loc, b.getBoolAttr(true)); } if (isa(op)) { - return b.create(loc, b.getBoolAttr(false)); + return arith::ConstantOp::create(b, loc, b.getBoolAttr(false)); } op->emitError("unimplemented lowering in createInitElementForReduceOp"); @@ -355,44 +363,44 @@ static Value createLinalgPayloadForReduceOp(OpBuilder &b, Location 
loc, convertScalarToDtype(b, loc, payloadArgs[0], resultElementType); Value result = payloadArgs[1]; if (isa(resultElementType)) - return b.create(loc, self, result); + return arith::AddFOp::create(b, loc, self, result); else if (isa(resultElementType)) - return b.create(loc, self, result); + return arith::AddIOp::create(b, loc, self, result); } else if (isa(op)) { Value self = convertScalarToDtype(b, loc, payloadArgs[0], resultElementType); Value result = payloadArgs[1]; if (isa(resultElementType)) - return b.create(loc, self, result); + return arith::MulFOp::create(b, loc, self, result); else if (isa(resultElementType)) - return b.create(loc, self, result); + return arith::MulIOp::create(b, loc, self, result); } else if (auto max = dyn_cast(op)) { Value self = convertScalarToDtype(b, loc, payloadArgs[0], resultElementType); Value result = payloadArgs[1]; if (isa(resultElementType)) - return b.create(loc, self, result); + return arith::MaximumFOp::create(b, loc, self, result); else if (isa(resultElementType)) { IntegerType intType = dyn_cast( cast(max.getSelf().getType()).getDtype()); if (intType.isUnsigned()) - return b.create(loc, self, result); + return arith::MaxUIOp::create(b, loc, self, result); if (intType.isSigned()) - return b.create(loc, self, result); + return arith::MaxSIOp::create(b, loc, self, result); } } else if (auto min = dyn_cast(op)) { Value self = convertScalarToDtype(b, loc, payloadArgs[0], resultElementType); Value result = payloadArgs[1]; if (isa(resultElementType)) - return b.create(loc, self, result); + return arith::MinimumFOp::create(b, loc, self, result); else if (isa(resultElementType)) { IntegerType intType = dyn_cast( cast(min.getSelf().getType()).getDtype()); if (intType.isUnsigned()) - return b.create(loc, self, result); + return arith::MinUIOp::create(b, loc, self, result); if (intType.isSigned()) - return b.create(loc, self, result); + return arith::MinSIOp::create(b, loc, self, result); } } else if (isa(op)) { // This creates 
payload for only the first of the two linalg.generic ops. @@ -404,8 +412,8 @@ static Value createLinalgPayloadForReduceOp(OpBuilder &b, Location loc, Value p = convertScalarToDtype(b, loc, adaptor.getP(), resultElementType); auto abs = createAbsOpForNormOps(b, loc, elem, resultElementType); - auto pow = b.create(loc, abs, p); - return b.create(loc, pow, result); + auto pow = math::PowFOp::create(b, loc, abs, p); + return arith::AddFOp::create(b, loc, pow, result); } else if (isa(op)) { // This creates payload for only the first of the two linalg.generic ops. // TODO: Short-circuit operations if `ord` is zero or one. @@ -417,28 +425,28 @@ static Value createLinalgPayloadForReduceOp(OpBuilder &b, Location loc, convertScalarToDtype(b, loc, adaptor.getOrd(), resultElementType); auto abs = createAbsOpForNormOps(b, loc, elem, resultElementType); - auto pow = b.create(loc, abs, ord); - return b.create(loc, pow, result); + auto pow = math::PowFOp::create(b, loc, abs, ord); + return arith::AddFOp::create(b, loc, pow, result); } else if (isa(op)) { Value elem = payloadArgs[0]; Value result = payloadArgs[1]; TypedAttr twoAttr = b.getFloatAttr(resultElementType, 2.0); - auto ord = b.create(loc, twoAttr); + auto ord = arith::ConstantOp::create(b, loc, twoAttr); auto abs = createAbsOpForNormOps(b, loc, elem, resultElementType); - auto pow = b.create(loc, abs, ord); - return b.create(loc, pow, result); + auto pow = math::PowFOp::create(b, loc, abs, ord); + return arith::AddFOp::create(b, loc, pow, result); } else if (isa(op)) { Value elem = payloadArgs[0]; Value result = payloadArgs[1]; Value self = convertScalarToDtype(b, loc, elem, resultElementType); - return b.create(loc, self, result); + return arith::AndIOp::create(b, loc, self, result); } else if (isa(op)) { Value elem = payloadArgs[0]; Value result = payloadArgs[1]; Value self = convertScalarToDtype(b, loc, elem, resultElementType); - return b.create(loc, self, result); + return arith::OrIOp::create(b, loc, self, result); 
} op->emitError("unimplemented lowering in createLinalgPayloadForReduceOp"); return nullptr; @@ -548,9 +556,9 @@ class ConvertReductionOp : public ConversionPattern { auto powBodyBuilder = [&](OpBuilder &builder, Location loc, ValueRange payloadArgs) { Value elem = convertScalarToDtype(builder, loc, payloadArgs[0], elemType); - auto result = builder.create(loc, elem, exponent); + auto result = math::PowFOp::create(builder, loc, elem, exponent); if (result) - builder.create(loc, Value{result}); + linalg::YieldOp::create(builder, loc, Value{result}); err = !result; }; @@ -580,9 +588,9 @@ class ConvertReductionOp : public ConversionPattern { // Raise each summed value to the inverse of the order of the norm. TypedAttr oneAttr = rewriter.getFloatAttr(elemType, 1.0); - auto oneValue = rewriter.create(loc, oneAttr); + auto oneValue = arith::ConstantOp::create(rewriter, loc, oneAttr); auto inverseOrdValue = - rewriter.create(loc, oneValue, ordValue); + arith::DivFOp::create(rewriter, loc, oneValue, ordValue); // Use the results of the first reduction operation from above to generate // a second reduction operation. 
@@ -607,7 +615,7 @@ class ConvertReductionOp : public ConversionPattern { Value result = createLinalgPayloadForReduceOp(builder, loc, payloadArgs, op, operands, elemType); if (result) - builder.create(loc, result); + linalg::YieldOp::create(builder, loc, result); err = !result; }; @@ -691,7 +699,7 @@ class ConvertReductionOp : public ConversionPattern { // the final result if (auto normOp = dyn_cast(op)) { auto halfAttr = rewriter.getFloatAttr(elemType, 0.5); - auto exp = rewriter.create(loc, halfAttr); + auto exp = arith::ConstantOp::create(rewriter, loc, halfAttr); reduceOp = createElementwiseExp(loc, elemType, exp, reduceOp, *opInfo, rewriter); } diff --git a/lib/Conversion/TorchToLinalg/TensorConstructors.cpp b/lib/Conversion/TorchToLinalg/TensorConstructors.cpp index e5df4dd72c12..eb2b2d813027 100644 --- a/lib/Conversion/TorchToLinalg/TensorConstructors.cpp +++ b/lib/Conversion/TorchToLinalg/TensorConstructors.cpp @@ -37,8 +37,8 @@ static Value extractSlice(ConversionPatternRewriter &rewriter, Location loc, SmallVector offsets(inputRank, rewriter.getIndexAttr(0)); if (sliceLoc == END) { Value dimSize = inputShape[dimension]; - Value one = rewriter.create(loc, 1); - Value endIdx = rewriter.create(loc, dimSize, one); + Value one = arith::ConstantIndexOp::create(rewriter, loc, 1); + Value endIdx = arith::SubIOp::create(rewriter, loc, dimSize, one); offsets[dimension] = getAsOpFoldResult(endIdx); } @@ -48,8 +48,8 @@ static Value extractSlice(ConversionPatternRewriter &rewriter, Location loc, sizes[i] = (i == dimension) ? 
rewriter.getIndexAttr(1) : getAsOpFoldResult(inputShape[i]); - Value extractedSlice = rewriter.create( - loc, input, offsets, sizes, allOneStrides); + Value extractedSlice = tensor::ExtractSliceOp::create( + rewriter, loc, input, offsets, sizes, allOneStrides); return extractedSlice; } @@ -58,7 +58,7 @@ static Value createTile(ConversionPatternRewriter &rewriter, Location loc, SmallVector slices(tileWidth, slice); if (tileWidth == 1) return slice; - return rewriter.create(loc, dimension, slices); + return tensor::ConcatOp::create(rewriter, loc, dimension, slices); } static Value replicationPad(ConversionPatternRewriter &rewriter, Location loc, @@ -89,7 +89,7 @@ static Value replicationPad(ConversionPatternRewriter &rewriter, Location loc, } if (resultParts.size() > 1) - res = rewriter.create(loc, dim, resultParts); + res = tensor::ConcatOp::create(rewriter, loc, dim, resultParts); } return res; } @@ -141,8 +141,8 @@ class ConvertAtenConstantPadNdOp if (!matchPattern(lowv, m_TorchConstantInt(&lowi))) { Type cty = tc->convertType(lowv.getType()); lowv = tc->materializeTargetConversion(rewriter, loc, cty, lowv); - lowv = rewriter.create(loc, rewriter.getIndexType(), - lowv); + lowv = arith::IndexCastOp::create(rewriter, loc, + rewriter.getIndexType(), lowv); lowPad.push_back(lowv); staticLow.push_back(ShapedType::kDynamic); } else { @@ -153,8 +153,8 @@ class ConvertAtenConstantPadNdOp if (!matchPattern(highv, m_TorchConstantInt(&highi))) { Type cty = tc->convertType(highv.getType()); highv = tc->materializeTargetConversion(rewriter, loc, cty, highv); - highv = rewriter.create( - loc, rewriter.getIndexType(), highv); + highv = arith::IndexCastOp::create(rewriter, loc, + rewriter.getIndexType(), highv); highPad.push_back(highv); staticHigh.push_back(ShapedType::kDynamic); } else { @@ -174,8 +174,8 @@ class ConvertAtenConstantPadNdOp Type padType = tensor::PadOp::inferResultType( cast(self.getType()), staticLow, staticHigh); - Value paddedInput = rewriter.create( - loc, 
padType, self, lowPad, highPad, castedValue); + Value paddedInput = tensor::PadOp::create(rewriter, loc, padType, self, + lowPad, highPad, castedValue); rewriter.replaceOpWithNewOp(op, newResultType, paddedInput); return success(); } @@ -471,8 +471,8 @@ class ConvertAtenEmptyMemoryFormatOp } // Create an uninitialized tensor of `resultSize` shape. - Value initTensor = rewriter.create( - loc, getAsOpFoldResult(resultSizeIndex), resultElementType); + Value initTensor = tensor::EmptyOp::create( + rewriter, loc, getAsOpFoldResult(resultSizeIndex), resultElementType); rewriter.replaceOpWithNewOp(op, resultType, initTensor); return success(); } @@ -523,46 +523,45 @@ class ConvertAtenArangeStartStepOp // ceil((end - start)/step) Value resultShape; if (isa(dtype)) { - Value subOut = rewriter.create(loc, end, start); - resultShape = rewriter.create(loc, subOut, step); + Value subOut = arith::SubIOp::create(rewriter, loc, end, start); + resultShape = arith::CeilDivSIOp::create(rewriter, loc, subOut, step); } else { - Value subOut = rewriter.create(loc, end, start); - Value divOut = rewriter.create(loc, subOut, step); - Value ceilOut = rewriter.create(loc, divOut); - resultShape = - rewriter.create(loc, rewriter.getI64Type(), ceilOut); + Value subOut = arith::SubFOp::create(rewriter, loc, end, start); + Value divOut = arith::DivFOp::create(rewriter, loc, subOut, step); + Value ceilOut = math::CeilOp::create(rewriter, loc, divOut); + resultShape = arith::FPToUIOp::create(rewriter, loc, + rewriter.getI64Type(), ceilOut); } resultShape = castIntToIndex(rewriter, loc, resultShape); - Value resultTensor = rewriter.create( - loc, getAsOpFoldResult(resultShape), dtype); + Value resultTensor = tensor::EmptyOp::create( + rewriter, loc, getAsOpFoldResult(resultShape), dtype); auto iteratorType = utils::IteratorType::parallel; AffineMap indexingMap = AffineMap::getMultiDimIdentityMap(1, op->getContext()); Value finalRes = - rewriter - .create( - loc, 
/*resultTensorTypes=*/resultTensor.getType(), - /*inputs=*/ValueRange({}), - /*outputs=*/resultTensor, - /*indexingMaps=*/indexingMap, - /*iteratorTypes=*/iteratorType, - [&](OpBuilder &b, Location loc, ValueRange payloadArgs) { - Value index = b.create(loc, 0); - index = castIndexToInt64(b, loc, index); - index = convertScalarToDtype(b, loc, index, dtype); - Value mulOut, result; - if (isa(dtype)) { - mulOut = b.create(loc, step, index); - result = b.create(loc, start, mulOut); - } else { - mulOut = b.create(loc, step, index); - result = b.create(loc, start, mulOut); - } - b.create(loc, result); - }) + linalg::GenericOp::create( + rewriter, loc, /*resultTensorTypes=*/resultTensor.getType(), + /*inputs=*/ValueRange({}), + /*outputs=*/resultTensor, + /*indexingMaps=*/indexingMap, + /*iteratorTypes=*/iteratorType, + [&](OpBuilder &b, Location loc, ValueRange payloadArgs) { + Value index = linalg::IndexOp::create(b, loc, 0); + index = castIndexToInt64(b, loc, index); + index = convertScalarToDtype(b, loc, index, dtype); + Value mulOut, result; + if (isa(dtype)) { + mulOut = arith::MulFOp::create(b, loc, step, index); + result = arith::AddFOp::create(b, loc, start, mulOut); + } else { + mulOut = arith::MulIOp::create(b, loc, step, index); + result = arith::AddIOp::create(b, loc, start, mulOut); + } + linalg::YieldOp::create(b, loc, result); + }) .getResult(0); rewriter.replaceOpWithNewOp(op, resultType, finalRes); return success(); diff --git a/lib/Conversion/TorchToLinalg/TensorScalarInterop.cpp b/lib/Conversion/TorchToLinalg/TensorScalarInterop.cpp index ab5fec18f9b2..35c39c49c21d 100644 --- a/lib/Conversion/TorchToLinalg/TensorScalarInterop.cpp +++ b/lib/Conversion/TorchToLinalg/TensorScalarInterop.cpp @@ -33,14 +33,15 @@ class ConvertAtenSizeIntOp : public OpConversionPattern { Value self = adaptor.getSelf(); Value dim = adaptor.getDim(); auto type = cast(self.getType()); - Value inputRank = rewriter.create( - loc, rewriter.getI64IntegerAttr(type.getRank())); + 
Value inputRank = arith::ConstantOp::create( + rewriter, loc, rewriter.getI64IntegerAttr(type.getRank())); Value dimPositive = toPositiveDimDynamic(rewriter, loc, dim, inputRank); if (!isAssumingStrictSymbolicShapes(rewriter)) { assertIsValidDim(rewriter, loc, dimPositive, inputRank); } - Value size = rewriter.create( - loc, adaptor.getSelf(), castIntToIndex(rewriter, loc, dimPositive)); + Value size = + tensor::DimOp::create(rewriter, loc, adaptor.getSelf(), + castIntToIndex(rewriter, loc, dimPositive)); rewriter.replaceOp(op, castIndexToInt64(rewriter, loc, size)); return success(); } @@ -86,15 +87,15 @@ class ConvertAtenTensorToScalarLikeOp : public OpConversionPattern { // `input` is a zero rank tensor or all the dimensions of the `input` tensor // are unit. Value constantOne = - rewriter.create(loc, rewriter.getI64IntegerAttr(1)); + arith::ConstantOp::create(rewriter, loc, rewriter.getI64IntegerAttr(1)); for (int64_t i = 0; i < inputRank; i++) checkDimEqualHelper(rewriter, loc, inputSizes[i], constantOne); // Extract the only element from the `input` tensor. 
Value constantZero = - rewriter.create(loc, rewriter.getIndexAttr(0)); + arith::ConstantOp::create(rewriter, loc, rewriter.getIndexAttr(0)); SmallVector indices(inputRank, constantZero); - Value result = rewriter.create(loc, input, indices); + Value result = tensor::ExtractOp::create(rewriter, loc, input, indices); Type resultType = this->getTypeConverter()->convertType(op->getResult(0).getType()); rewriter.replaceOp(op, convertScalarToDtype(rewriter, loc, result, @@ -221,16 +222,16 @@ class ConvertAtenFullOp : public OpConversionPattern { if (full.getType() != resultTy.getElementType()) { if (isa(full.getType())) { - full = rewriter.create(loc, resultTy.getElementType(), - full); + full = arith::TruncFOp::create(rewriter, loc, resultTy.getElementType(), + full); } else if (isa(full.getType())) { - full = rewriter.create(loc, resultTy.getElementType(), - full); + full = arith::TruncIOp::create(rewriter, loc, resultTy.getElementType(), + full); } } - Value outTensor = rewriter.create( - loc, filteredShape, resultTy.getElementType()); + Value outTensor = tensor::EmptyOp::create(rewriter, loc, filteredShape, + resultTy.getElementType()); rewriter.replaceOpWithNewOp(op, full, outTensor); diff --git a/lib/Conversion/TorchToLinalg/Uncategorized.cpp b/lib/Conversion/TorchToLinalg/Uncategorized.cpp index e89056355785..dfe4c431600b 100644 --- a/lib/Conversion/TorchToLinalg/Uncategorized.cpp +++ b/lib/Conversion/TorchToLinalg/Uncategorized.cpp @@ -45,14 +45,14 @@ template (type)) - return b.create(loc, fpred, lhs, rhs); + return arith::CmpFOp::create(b, loc, fpred, lhs, rhs); if (IntegerType intType = dyn_cast(type)) { if (intType.isUnsigned()) - return b.create(loc, iupred, lhs, rhs); + return arith::CmpIOp::create(b, loc, iupred, lhs, rhs); if (intType.isSigned()) - return b.create(loc, ispred, lhs, rhs); + return arith::CmpIOp::create(b, loc, ispred, lhs, rhs); assert(intType.getWidth() == 1); - return b.create(loc, iupred, lhs, rhs); + return arith::CmpIOp::create(b, 
loc, iupred, lhs, rhs); } llvm_unreachable("Unhandled element type for comparison"); } @@ -116,23 +116,24 @@ static Value createNotEqual(OpBuilder &b, Location loc, Type elementalType, static Value buildNormalCdf(OpBuilder &b, Location &loc, Value x, Value mean, Value sigma) { Type elementType = x.getType(); - Value xMinusMean = b.create(loc, x, mean); - Value two = b.create(loc, FloatAttr::get(elementType, 2)); - Value sqrt2 = b.create(loc, two); - Value erfArg = b.create(loc, xMinusMean, sqrt2); - Value erf = b.create(loc, erfArg); - Value one = b.create(loc, FloatAttr::get(elementType, 1)); - Value erfPlus1 = b.create(loc, one, erf); + Value xMinusMean = arith::SubFOp::create(b, loc, x, mean); + Value two = arith::ConstantOp::create(b, loc, FloatAttr::get(elementType, 2)); + Value sqrt2 = math::SqrtOp::create(b, loc, two); + Value erfArg = arith::DivFOp::create(b, loc, xMinusMean, sqrt2); + Value erf = math::ErfOp::create(b, loc, erfArg); + Value one = arith::ConstantOp::create(b, loc, FloatAttr::get(elementType, 1)); + Value erfPlus1 = arith::AddFOp::create(b, loc, one, erf); Value oneHalf = - b.create(loc, FloatAttr::get(elementType, 0.5)); - Value normalCdf = b.create(loc, oneHalf, erfPlus1); + arith::ConstantOp::create(b, loc, FloatAttr::get(elementType, 0.5)); + Value normalCdf = arith::MulFOp::create(b, loc, oneHalf, erfPlus1); return normalCdf; } static Value buildUnitNormalCdf(OpBuilder &b, Location &loc, Value x) { Type elementType = x.getType(); - Value zero = b.create(loc, FloatAttr::get(elementType, 0)); - Value one = b.create(loc, FloatAttr::get(elementType, 1)); + Value zero = + arith::ConstantOp::create(b, loc, FloatAttr::get(elementType, 0)); + Value one = arith::ConstantOp::create(b, loc, FloatAttr::get(elementType, 1)); return buildNormalCdf(b, loc, x, zero, one); } @@ -149,7 +150,7 @@ static Value createFpOpWithDtype(OpBuilder &b, const TypeConverter *converter, computeTy = b.getF32Type(); Location loc = op->getLoc(); Value arg = 
convertScalarToDtype(b, loc, payloadArg, computeTy, inTTy); - auto newOp = b.create(loc, arg); + auto newOp = MathOpTy::create(b, loc, arg); return convertScalarToDtype(b, loc, newOp, outTy, std::nullopt, outTTy); } @@ -221,20 +222,20 @@ createTriangularMatrix(OpBuilder &b, Location loc, ValueRange payloadArgs, uint64_t inputRank = inputType.getRank(); // Use the indices of the two innermost dimensions. - auto rowIndex = b.create(loc, inputRank - 2); + auto rowIndex = linalg::IndexOp::create(b, loc, inputRank - 2); Value rowIndexI64 = castIndexToInt64(b, loc, rowIndex); - auto colIndex = b.create(loc, inputRank - 1); + auto colIndex = linalg::IndexOp::create(b, loc, inputRank - 1); Value colIndexI64 = castIndexToInt64(b, loc, colIndex); // columnIndex >= rowIndex + diagonal? auto sum = - b.create(loc, rowIndexI64, /*diagonal=*/operands[1]); - auto pred = b.create(loc, predicate, colIndexI64, sum); + arith::AddIOp::create(b, loc, rowIndexI64, /*diagonal=*/operands[1]); + auto pred = arith::CmpIOp::create(b, loc, predicate, colIndexI64, sum); Value scalar = payloadArgs[0]; Type elementType = inputType.getElementType(); Value zero = getConstant(b, loc, 0, elementType); - result = b.create(loc, pred, scalar, zero); + result = arith::SelectOp::create(b, loc, pred, scalar, zero); return success(); } @@ -257,13 +258,13 @@ Value createDivModePayload(OpBuilder &b, Location loc, Value quotient; if (isa(dtype)) { - quotient = b.create(loc, lhs, rhs); + quotient = arith::DivFOp::create(b, loc, lhs, rhs); } else if (dtype.isUnsignedInteger()) { - quotient = b.create(loc, lhs, rhs); + quotient = arith::DivUIOp::create(b, loc, lhs, rhs); } else { assert(dtype.isInteger() && "dtype should be an integer (signless or signed)"); - quotient = b.create(loc, lhs, rhs); + quotient = arith::DivSIOp::create(b, loc, lhs, rhs); } if (isa(op.getRoundingMode().getType())) @@ -285,24 +286,24 @@ Value createDivModePayload(OpBuilder &b, Location loc, } // float - Value ceil = b.create(loc, 
quotient); - Value floor = b.create(loc, quotient); - Value cstZero = b.create(loc, b.getZeroAttr(dtype)); - Value pred = b.create(loc, arith::CmpFPredicate::ULT, - quotient, cstZero); - return b.create(loc, pred, ceil, floor); + Value ceil = math::CeilOp::create(b, loc, quotient); + Value floor = math::FloorOp::create(b, loc, quotient); + Value cstZero = arith::ConstantOp::create(b, loc, b.getZeroAttr(dtype)); + Value pred = arith::CmpFOp::create(b, loc, arith::CmpFPredicate::ULT, + quotient, cstZero); + return arith::SelectOp::create(b, loc, pred, ceil, floor); } if (roundingMode == "floor") { // "floor" - rounds the results of the division down. Equivalent to // floor division in Python (the // operator) if (isa(dtype)) - return b.create(loc, quotient); + return math::FloorOp::create(b, loc, quotient); if (!dtype.isUnsignedInteger()) { Type defaultIntToFloatType = b.getF64Type(); lhs = convertScalarToDtype(b, loc, lhs, defaultIntToFloatType); rhs = convertScalarToDtype(b, loc, rhs, defaultIntToFloatType); - quotient = b.create(loc, lhs, rhs); - Value floor = b.create(loc, quotient); + quotient = arith::DivFOp::create(b, loc, lhs, rhs); + Value floor = math::FloorOp::create(b, loc, quotient); Value convert = convertScalarToDtype(b, loc, floor, dtype); return convert; } @@ -335,41 +336,41 @@ Value createRemainderPayload(OpBuilder &b, Location loc, // https://github.com/python/cpython/blob/2afd1751dd9a35d4ec03b708e3e5cddd72c43f7e/Objects/longobject.c#L3662 Value result; if (isa(dtype)) { - Value remainder = b.create(loc, lhs, rhs); + Value remainder = arith::RemFOp::create(b, loc, lhs, rhs); - Value zero = b.create(loc, b.getZeroAttr(dtype)); - Value remainderNotEqualToZero = b.create( - loc, arith::CmpFPredicate::ONE, remainder, zero); + Value zero = arith::ConstantOp::create(b, loc, b.getZeroAttr(dtype)); + Value remainderNotEqualToZero = arith::CmpFOp::create( + b, loc, arith::CmpFPredicate::ONE, remainder, zero); Value otherLessThanZero = - b.create(loc, 
arith::CmpFPredicate::OLT, rhs, zero); - Value remainderLessThanZero = b.create( - loc, arith::CmpFPredicate::OLT, remainder, zero); + arith::CmpFOp::create(b, loc, arith::CmpFPredicate::OLT, rhs, zero); + Value remainderLessThanZero = arith::CmpFOp::create( + b, loc, arith::CmpFPredicate::OLT, remainder, zero); Value xorCondition = - b.create(loc, otherLessThanZero, remainderLessThanZero); + arith::XOrIOp::create(b, loc, otherLessThanZero, remainderLessThanZero); Value condition = - b.create(loc, remainderNotEqualToZero, xorCondition); - Value fixedRemainder = b.create(loc, remainder, rhs); + arith::AndIOp::create(b, loc, remainderNotEqualToZero, xorCondition); + Value fixedRemainder = arith::AddFOp::create(b, loc, remainder, rhs); result = - b.create(loc, condition, fixedRemainder, remainder); + arith::SelectOp::create(b, loc, condition, fixedRemainder, remainder); } else { assert(dtype.isInteger() && "dtype should be a float or integer (signless or signed)"); - Value remainder = b.create(loc, lhs, rhs); + Value remainder = arith::RemSIOp::create(b, loc, lhs, rhs); - Value zero = b.create(loc, b.getZeroAttr(dtype)); - Value remainderNotEqualToZero = - b.create(loc, arith::CmpIPredicate::ne, remainder, zero); + Value zero = arith::ConstantOp::create(b, loc, b.getZeroAttr(dtype)); + Value remainderNotEqualToZero = arith::CmpIOp::create( + b, loc, arith::CmpIPredicate::ne, remainder, zero); Value otherLessThanZero = - b.create(loc, arith::CmpIPredicate::slt, rhs, zero); - Value remainderLessThanZero = b.create( - loc, arith::CmpIPredicate::slt, remainder, zero); + arith::CmpIOp::create(b, loc, arith::CmpIPredicate::slt, rhs, zero); + Value remainderLessThanZero = arith::CmpIOp::create( + b, loc, arith::CmpIPredicate::slt, remainder, zero); Value xorCondition = - b.create(loc, otherLessThanZero, remainderLessThanZero); + arith::XOrIOp::create(b, loc, otherLessThanZero, remainderLessThanZero); Value condition = - b.create(loc, remainderNotEqualToZero, xorCondition); - 
Value fixedRemainder = b.create(loc, remainder, rhs); + arith::AndIOp::create(b, loc, remainderNotEqualToZero, xorCondition); + Value fixedRemainder = arith::AddIOp::create(b, loc, remainder, rhs); result = - b.create(loc, condition, fixedRemainder, remainder); + arith::SelectOp::create(b, loc, condition, fixedRemainder, remainder); } return result; } @@ -378,9 +379,9 @@ static Value createLinalgPayloadCalculationForElementwiseOp( OpBuilder &b, Location loc, const TypeConverter *converter, ValueRange payloadArgs, Operation *op, ArrayRef operands) { if (isa(op)) - return b.create(loc, payloadArgs[0]); + return math::FloorOp::create(b, loc, payloadArgs[0]); if (isa(op)) - return b.create(loc, payloadArgs[0]); + return math::CeilOp::create(b, loc, payloadArgs[0]); if (isa(op)) { return createFpOpWithDtype(b, converter, payloadArgs[0], op); } @@ -472,7 +473,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp( .getElementType(); Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], dtype); Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype); - return b.create(loc, lhs, rhs); + return arith::AndIOp::create(b, loc, lhs, rhs); } if (auto bitwiseAndScalar = dyn_cast(op)) { Type dtype = cast( @@ -491,7 +492,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp( Value other = convertScalarToDtype(b, loc, operands[1], dtype, /*srcOriginalDtype=*/std::nullopt, /*dstOriginalDtype=*/resultElementType); - return b.create(loc, self, other); + return arith::AndIOp::create(b, loc, self, other); } if (auto bitwiseOrTensor = dyn_cast(op)) { if (isa( @@ -505,7 +506,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp( .getElementType(); Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], dtype); Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype); - return b.create(loc, lhs, rhs); + return arith::OrIOp::create(b, loc, lhs, rhs); } if (auto bitwiseXorTensor = dyn_cast(op)) { if (isa( @@ -519,7 +520,7 @@ static Value 
createLinalgPayloadCalculationForElementwiseOp( .getElementType(); Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], dtype); Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype); - return b.create(loc, lhs, rhs); + return arith::XOrIOp::create(b, loc, lhs, rhs); } if (auto bitwiseRightShiftTensor = dyn_cast(op)) { @@ -533,7 +534,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp( } Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], dtype); Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype); - return b.create(loc, lhs, rhs); + return arith::ShRSIOp::create(b, loc, lhs, rhs); } if (auto bitwiseLeftShiftTensor = dyn_cast(op)) { @@ -547,7 +548,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp( } Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], dtype); Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype); - return b.create(loc, lhs, rhs); + return arith::ShLIOp::create(b, loc, lhs, rhs); } if (isa(op)) { MLIRContext *context = op->getContext(); @@ -555,17 +556,17 @@ static Value createLinalgPayloadCalculationForElementwiseOp( Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], floatDtype); Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], floatDtype); Value zero = - b.create(loc, b.getFloatAttr(floatDtype, 0)); + arith::ConstantOp::create(b, loc, b.getFloatAttr(floatDtype, 0)); Value lhsTest = createNotEqual(b, loc, floatDtype, lhs, zero); Value rhsTest = createNotEqual(b, loc, floatDtype, rhs, zero); if (isa(op)) { - return b.create(loc, lhsTest, rhsTest); + return arith::OrIOp::create(b, loc, lhsTest, rhsTest); } if (isa(op)) { - return b.create(loc, lhsTest, rhsTest); + return arith::AndIOp::create(b, loc, lhsTest, rhsTest); } if (isa(op)) { - return b.create(loc, lhsTest, rhsTest); + return arith::XOrIOp::create(b, loc, lhsTest, rhsTest); } llvm_unreachable("Unknown op type"); } @@ -574,7 +575,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp( Type 
floatDtype = mlir::Float64Type::get(context); Value self = convertScalarToDtype(b, loc, payloadArgs[0], floatDtype); Value zero = - b.create(loc, b.getFloatAttr(floatDtype, 0)); + arith::ConstantOp::create(b, loc, b.getFloatAttr(floatDtype, 0)); return createEqual(b, loc, floatDtype, self, zero); } if (auto complex = dyn_cast(op)) { @@ -585,17 +586,17 @@ static Value createLinalgPayloadCalculationForElementwiseOp( Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], stype); Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], stype); - return b.create(loc, ctype, lhs, rhs); + return complex::CreateOp::create(b, loc, ctype, lhs, rhs); } if (isa(op)) { if (isa(payloadArgs[0].getType())) - return b.create(loc, payloadArgs[0]); - return b.create(loc, payloadArgs[0]); + return math::AbsIOp::create(b, loc, payloadArgs[0]); + return math::AbsFOp::create(b, loc, payloadArgs[0]); } if (isa(op)) { - Value abs = b.create(loc, payloadArgs[0]); - Value infinity = b.create( - loc, + Value abs = math::AbsFOp::create(b, loc, payloadArgs[0]); + Value infinity = arith::ConstantOp::create( + b, loc, b.getFloatAttr(abs.getType(), std::numeric_limits::infinity())); return createEqual(b, loc, abs.getType(), abs, infinity); } @@ -611,12 +612,12 @@ static Value createLinalgPayloadCalculationForElementwiseOp( Value arg = payloadArgs[0]; arg = convertScalarToDtype(b, loc, payloadArgs[0], computeTy, inTTy); - auto negate = b.create(loc, arg); + auto negate = arith::NegFOp::create(b, loc, arg); auto one = - b.create(loc, FloatAttr::get(negate.getType(), 1)); - auto exp = b.create(loc, negate); - auto added = b.create(loc, exp, one); - auto div = b.create(loc, one, added); + arith::ConstantOp::create(b, loc, FloatAttr::get(negate.getType(), 1)); + auto exp = math::ExpOp::create(b, loc, negate); + auto added = arith::AddFOp::create(b, loc, exp, one); + auto div = arith::DivFOp::create(b, loc, one, added); return convertScalarToDtype(b, loc, div, outTy, std::nullopt, outTTy); } if (auto 
relu = dyn_cast(op)) { @@ -653,35 +654,35 @@ static Value createLinalgPayloadCalculationForElementwiseOp( } zeroPoint = converter->materializeTargetConversion( b, loc, converter->convertType(zeroPoint.getType()), zeroPoint); - auto minForIntTypeValue = b.create( - loc, b.getIntegerAttr(zeroPoint.getType(), minForIntType)); - auto maxForIntTypeValue = b.create( - loc, b.getIntegerAttr(zeroPoint.getType(), maxForIntType)); - auto zpLtMax = b.create(loc, arith::CmpIPredicate::slt, - zeroPoint, maxForIntTypeValue); - b.create( - loc, zpLtMax, + auto minForIntTypeValue = arith::ConstantOp::create( + b, loc, b.getIntegerAttr(zeroPoint.getType(), minForIntType)); + auto maxForIntTypeValue = arith::ConstantOp::create( + b, loc, b.getIntegerAttr(zeroPoint.getType(), maxForIntType)); + auto zpLtMax = arith::CmpIOp::create(b, loc, arith::CmpIPredicate::slt, + zeroPoint, maxForIntTypeValue); + cf::AssertOp::create( + b, loc, zpLtMax, b.getStringAttr("Invalid Quantization: quantized relu with " "zero-point > max qint")); - auto zpLtMin = b.create(loc, arith::CmpIPredicate::slt, - zeroPoint, minForIntTypeValue); - zeroPoint = b.create(loc, zpLtMin, minForIntTypeValue, - zeroPoint); - zeroPoint = b.create(loc, arg.getType(), zeroPoint); + auto zpLtMin = arith::CmpIOp::create(b, loc, arith::CmpIPredicate::slt, + zeroPoint, minForIntTypeValue); + zeroPoint = arith::SelectOp::create(b, loc, zpLtMin, minForIntTypeValue, + zeroPoint); + zeroPoint = arith::TruncIOp::create(b, loc, arg.getType(), zeroPoint); } else { zeroPoint = - b.create(loc, b.getZeroAttr(arg.getType())); + arith::ConstantOp::create(b, loc, b.getZeroAttr(arg.getType())); } Value cmp; if (intType) { auto pred = isUnsigned ? 
arith::CmpIPredicate::ugt : arith::CmpIPredicate::sgt; - cmp = b.create(loc, pred, arg, zeroPoint); + cmp = arith::CmpIOp::create(b, loc, pred, arg, zeroPoint); } else { - cmp = b.create(loc, arith::CmpFPredicate::UGT, arg, - zeroPoint); + cmp = arith::CmpFOp::create(b, loc, arith::CmpFPredicate::UGT, arg, + zeroPoint); } - return b.create(loc, cmp, arg, zeroPoint); + return arith::SelectOp::create(b, loc, cmp, arg, zeroPoint); } if (auto round = dyn_cast(op)) { if (!isa( @@ -689,7 +690,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp( round.emitError("unimplemented: non-floating point dtype"); return nullptr; } - return b.create(loc, payloadArgs[0]); + return math::RoundEvenOp::create(b, loc, payloadArgs[0]); } if (auto prelu = dyn_cast(op)) { if (!isa( @@ -699,17 +700,17 @@ static Value createLinalgPayloadCalculationForElementwiseOp( } Type elementType = payloadArgs[0].getType(); Value constZero = - b.create(loc, b.getZeroAttr(elementType)); - Value pred = b.create(loc, arith::CmpFPredicate::UGT, - payloadArgs[0], constZero); + arith::ConstantOp::create(b, loc, b.getZeroAttr(elementType)); + Value pred = arith::CmpFOp::create(b, loc, arith::CmpFPredicate::UGT, + payloadArgs[0], constZero); Value positivePart = - b.create(loc, pred, payloadArgs[0], constZero); + arith::SelectOp::create(b, loc, pred, payloadArgs[0], constZero); Value negativePart = - b.create(loc, pred, constZero, payloadArgs[0]); + arith::SelectOp::create(b, loc, pred, constZero, payloadArgs[0]); Value scale = convertScalarToDtype(b, loc, payloadArgs[1], elementType); Value scaledNegativePart = - b.create(loc, negativePart, scale); - return b.create(loc, positivePart, scaledNegativePart); + arith::MulFOp::create(b, loc, negativePart, scale); + return arith::AddFOp::create(b, loc, positivePart, scaledNegativePart); } if (auto gelu = dyn_cast(op)) { if (!isa( @@ -726,32 +727,32 @@ static Value createLinalgPayloadCalculationForElementwiseOp( } if (approximate == "none") { Value 
multiplier = buildUnitNormalCdf(b, loc, payloadArgs[0]); - return b.create(loc, payloadArgs[0], multiplier); + return arith::MulFOp::create(b, loc, payloadArgs[0], multiplier); } if (approximate == "tanh") { // GELU(x)=0.5∗x∗(1+Tanh((2/π)^1/2 * (x+0.044715∗x^3))) // Ref: https://pytorch.org/docs/stable/generated/torch.nn.GELU.html - Value cstThree = b.create( - loc, IntegerAttr::get(IntegerType::get(op->getContext(), 64), 3)); - Value xCube = b.create(loc, payloadArgs[0], cstThree); + Value cstThree = arith::ConstantOp::create( + b, loc, IntegerAttr::get(IntegerType::get(op->getContext(), 64), 3)); + Value xCube = math::FPowIOp::create(b, loc, payloadArgs[0], cstThree); Type elementType = payloadArgs[0].getType(); - Value cstAlpha = b.create( - loc, FloatAttr::get(elementType, 0.044715)); - Value xCubeMulAlpha = b.create(loc, xCube, cstAlpha); + Value cstAlpha = arith::ConstantOp::create( + b, loc, FloatAttr::get(elementType, 0.044715)); + Value xCubeMulAlpha = arith::MulFOp::create(b, loc, xCube, cstAlpha); Value xPlusXCubeMulAlpha = - b.create(loc, payloadArgs[0], xCubeMulAlpha); - Value cstBeta = b.create( - loc, FloatAttr::get(elementType, 0.7977240352174656)); + arith::AddFOp::create(b, loc, payloadArgs[0], xCubeMulAlpha); + Value cstBeta = arith::ConstantOp::create( + b, loc, FloatAttr::get(elementType, 0.7977240352174656)); Value betaMulX = - b.create(loc, cstBeta, xPlusXCubeMulAlpha); - Value tanh = b.create(loc, betaMulX); + arith::MulFOp::create(b, loc, cstBeta, xPlusXCubeMulAlpha); + Value tanh = math::TanhOp::create(b, loc, betaMulX); Value cstOne = - b.create(loc, FloatAttr::get(elementType, 1.0)); - Value onePlusTanh = b.create(loc, cstOne, tanh); + arith::ConstantOp::create(b, loc, FloatAttr::get(elementType, 1.0)); + Value onePlusTanh = arith::AddFOp::create(b, loc, cstOne, tanh); Value cstHalf = - b.create(loc, FloatAttr::get(elementType, 0.5)); - Value multiplier = b.create(loc, cstHalf, onePlusTanh); - return b.create(loc, payloadArgs[0],
multiplier); + arith::ConstantOp::create(b, loc, FloatAttr::get(elementType, 0.5)); + Value multiplier = arith::MulFOp::create(b, loc, cstHalf, onePlusTanh); + return arith::MulFOp::create(b, loc, payloadArgs[0], multiplier); } gelu.emitError("unimplemented: approximate value should be none or tanh"); return nullptr; @@ -769,27 +770,27 @@ static Value createLinalgPayloadCalculationForElementwiseOp( approximate != "none") return nullptr; Type elementType = payloadArgs[1].getType(); - Value cstAlpha0 = b.create( - loc, FloatAttr::get(elementType, 1.12837916709551257390)); - Value cstAlpha1 = b.create( - loc, FloatAttr::get(elementType, 0.70710678118654752440)); + Value cstAlpha0 = arith::ConstantOp::create( + b, loc, FloatAttr::get(elementType, 1.12837916709551257390)); + Value cstAlpha1 = arith::ConstantOp::create( + b, loc, FloatAttr::get(elementType, 0.70710678118654752440)); Value oneHalf = - b.create(loc, FloatAttr::get(elementType, 0.5)); - Value kAlpha = b.create(loc, cstAlpha0, cstAlpha1); - Value kAlphaHalf = b.create(loc, kAlpha, oneHalf); + arith::ConstantOp::create(b, loc, FloatAttr::get(elementType, 0.5)); + Value kAlpha = arith::MulFOp::create(b, loc, cstAlpha0, cstAlpha1); + Value kAlphaHalf = arith::MulFOp::create(b, loc, kAlpha, oneHalf); Value negOneHalf = - b.create(loc, FloatAttr::get(elementType, -0.5)); + arith::ConstantOp::create(b, loc, FloatAttr::get(elementType, -0.5)); Value inputSquared = - b.create(loc, payloadArgs[1], payloadArgs[1]); + arith::MulFOp::create(b, loc, payloadArgs[1], payloadArgs[1]); Value negHalfInputSquared = - b.create(loc, inputSquared, negOneHalf); - Value dinput = b.create(loc, negHalfInputSquared); + arith::MulFOp::create(b, loc, inputSquared, negOneHalf); + Value dinput = math::ExpOp::create(b, loc, negHalfInputSquared); Value cdf = buildUnitNormalCdf(b, loc, payloadArgs[1]); - Value dinputInput = b.create(loc, dinput, payloadArgs[1]); + Value dinputInput = arith::MulFOp::create(b, loc, dinput, payloadArgs[1]); 
Value dinputInputAlpha = - b.create(loc, dinputInput, kAlphaHalf); - Value cdfExt = b.create(loc, dinputInputAlpha, cdf); - return b.create(loc, payloadArgs[0], cdfExt); + arith::MulFOp::create(b, loc, dinputInput, kAlphaHalf); + Value cdfExt = arith::AddFOp::create(b, loc, dinputInputAlpha, cdf); + return arith::MulFOp::create(b, loc, payloadArgs[0], cdfExt); } if (auto hardtanhBackward = dyn_cast(op)) { AtenHardtanhBackwardOp::Adaptor adaptor(operands); @@ -802,15 +803,15 @@ static Value createLinalgPayloadCalculationForElementwiseOp( Type elementType = gradOutput.getType(); Value self = convertScalarToDtype(b, loc, payloadArgs[1], elementType); Value constantZero = - b.create(loc, FloatAttr::get(elementType, 0.0)); + arith::ConstantOp::create(b, loc, FloatAttr::get(elementType, 0.0)); Value min = convertScalarToDtype(b, loc, adaptor.getMinVal(), elementType); Value max = convertScalarToDtype(b, loc, adaptor.getMaxVal(), elementType); Value lesser = - b.create(loc, arith::CmpFPredicate::ULT, self, min); + arith::CmpFOp::create(b, loc, arith::CmpFPredicate::ULT, self, min); Value greater = - b.create(loc, arith::CmpFPredicate::UGT, self, max); - Value cmp = b.create(loc, lesser, greater); - return b.create(loc, cmp, constantZero, gradOutput); + arith::CmpFOp::create(b, loc, arith::CmpFPredicate::UGT, self, max); + Value cmp = arith::OrIOp::create(b, loc, lesser, greater); + return arith::SelectOp::create(b, loc, cmp, constantZero, gradOutput); } if (auto add = dyn_cast(op)) { AtenAddTensorOp::Adaptor adaptor(operands); @@ -827,14 +828,14 @@ static Value createLinalgPayloadCalculationForElementwiseOp( /*srcOriginalDtype=*/std::nullopt, /*dstOriginalDtype=*/resultElementType); if (isa(dtype)) { - Value scaled = b.create(loc, rhs, alpha); - return b.create(loc, lhs, scaled); + Value scaled = arith::MulFOp::create(b, loc, rhs, alpha); + return arith::AddFOp::create(b, loc, lhs, scaled); } else if (dtype.isInteger(1)) { - Value scaled = b.create(loc, rhs, alpha); - 
return b.create(loc, lhs, scaled); + Value scaled = arith::MulIOp::create(b, loc, rhs, alpha); + return arith::OrIOp::create(b, loc, lhs, scaled); } else { - Value scaled = b.create(loc, rhs, alpha); - return b.create(loc, lhs, scaled); + Value scaled = arith::MulIOp::create(b, loc, rhs, alpha); + return arith::AddIOp::create(b, loc, lhs, scaled); } } if (auto sub = dyn_cast(op)) { @@ -853,11 +854,11 @@ static Value createLinalgPayloadCalculationForElementwiseOp( /*dstOriginalDtype=*/resultElementType, /*originalScalar=*/sub.getAlpha()); if (isa(dtype)) { - Value scaled = b.create(loc, rhs, alpha); - return b.create(loc, lhs, scaled); + Value scaled = arith::MulFOp::create(b, loc, rhs, alpha); + return arith::SubFOp::create(b, loc, lhs, scaled); } else { - Value scaled = b.create(loc, rhs, alpha); - return b.create(loc, lhs, scaled); + Value scaled = arith::MulIOp::create(b, loc, rhs, alpha); + return arith::SubIOp::create(b, loc, lhs, scaled); } } if (auto lshiftScalar = dyn_cast(op)) { @@ -869,7 +870,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp( convertScalarToDtype(b, loc, operands[1], dtype, /*srcOriginalDtype=*/operands[1].getType(), /*dstOriginalDtype=*/dtype); - return b.create(loc, self, other); + return arith::ShLIOp::create(b, loc, self, other); } if (auto rshiftScalar = dyn_cast(op)) { Type dtype = @@ -880,7 +881,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp( convertScalarToDtype(b, loc, operands[1], dtype, /*srcOriginalDtype=*/operands[1].getType(), /*dstOriginalDtype=*/dtype); - return b.create(loc, self, other); + return arith::ShRUIOp::create(b, loc, self, other); } if (auto subScalar = dyn_cast(op)) { Type dtype = @@ -892,11 +893,11 @@ static Value createLinalgPayloadCalculationForElementwiseOp( b, loc, operands[2], dtype, /*srcOriginalDtype=*/operands[2].getType(), /*dstOriginalDtype=*/dtype); if (isa(dtype)) { - Value mult = b.create(loc, other, alpha); - return b.create(loc, self, mult); + Value mult = 
arith::MulFOp::create(b, loc, other, alpha); + return arith::SubFOp::create(b, loc, self, mult); } else if (isa(dtype)) { - Value mult = b.create(loc, other, alpha); - return b.create(loc, self, mult); + Value mult = arith::MulIOp::create(b, loc, other, alpha); + return arith::SubIOp::create(b, loc, self, mult); } subScalar.emitError("unimplemented: dtype other than float and integer " "types are not supported."); @@ -918,11 +919,11 @@ static Value createLinalgPayloadCalculationForElementwiseOp( /*srcOriginalDtype=*/std::nullopt, /*dstOriginalDtype=*/resultElementType); if (isa(dtype)) { - Value mult = b.create(loc, other, alpha); - return b.create(loc, self, mult); + Value mult = arith::MulFOp::create(b, loc, other, alpha); + return arith::AddFOp::create(b, loc, self, mult); } else if (isa(dtype)) { - Value mult = b.create(loc, other, alpha); - return b.create(loc, self, mult); + Value mult = arith::MulIOp::create(b, loc, other, alpha); + return arith::AddIOp::create(b, loc, self, mult); } addScalar.emitError("unimplemented: dtype other than float and integer " "types are not supported."); @@ -935,11 +936,11 @@ static Value createLinalgPayloadCalculationForElementwiseOp( Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], dtype); Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype); if (isa(dtype)) { - return b.create(loc, lhs, rhs); + return arith::MulFOp::create(b, loc, lhs, rhs); } else if (isa(dtype)) { - return b.create(loc, lhs, rhs); + return complex::MulOp::create(b, loc, lhs, rhs); } else { - return b.create(loc, lhs, rhs); + return arith::MulIOp::create(b, loc, lhs, rhs); } } if (auto atan2 = dyn_cast(op)) { @@ -951,7 +952,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp( } Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], dtype); Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype); - return b.create(loc, lhs, rhs); + return math::Atan2Op::create(b, loc, lhs, rhs); } if (auto ltTensor = dyn_cast(op)) { 
return createCompareOp(b, loc, ltTensor, payloadArgs[0], payloadArgs[1]); @@ -978,11 +979,11 @@ static Value createLinalgPayloadCalculationForElementwiseOp( Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], dtype); Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype); if (isa(dtype)) - return b.create(loc, lhs, rhs); + return arith::DivFOp::create(b, loc, lhs, rhs); else if (isa(dtype)) { if (dtype.isUnsignedInteger()) - return b.create(loc, lhs, rhs); - return b.create(loc, lhs, rhs); + return arith::DivUIOp::create(b, loc, lhs, rhs); + return arith::DivSIOp::create(b, loc, lhs, rhs); } div.emitError("unimplemented: non-floating point and non-integer dtype"); return nullptr; @@ -1003,7 +1004,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp( } Value selfPromoted = convertScalarToDtype(b, loc, operands[0], dtype); Value expPromoted = convertScalarToDtype(b, loc, payloadArgs[0], dtype); - return b.create(loc, selfPromoted, expPromoted); + return math::PowFOp::create(b, loc, selfPromoted, expPromoted); } if (auto pow = dyn_cast(op)) { @@ -1019,11 +1020,11 @@ static Value createLinalgPayloadCalculationForElementwiseOp( return nullptr; } if (isa(expType)) { - return b.create(loc, payloadArgs[0], exp); + return math::FPowIOp::create(b, loc, payloadArgs[0], exp); } Type dtype = cast(pow.getSelf().getType()).getDtype(); Value expPromoted = convertScalarToDtype(b, loc, operands[1], dtype); - return b.create(loc, payloadArgs[0], expPromoted); + return math::PowFOp::create(b, loc, payloadArgs[0], expPromoted); } if (auto pow = dyn_cast(op)) { @@ -1042,7 +1043,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp( powType = mlir::Float64Type::get(op->getContext()); Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], powType); Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], powType); - auto powOp = b.create(loc, lhs, rhs); + auto powOp = math::PowFOp::create(b, loc, lhs, rhs); return convertScalarToDtype(b, loc, powOp, 
dtype); } @@ -1053,7 +1054,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp( imag.emitError("unimplemented: non-floating point dtype"); return nullptr; } - Value imagVal = b.create(loc, payloadArgs[0]); + Value imagVal = complex::ImOp::create(b, loc, payloadArgs[0]); return imagVal; } @@ -1064,7 +1065,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp( real.emitError("unimplemented: non-floating point dtype"); return nullptr; } - Value realVal = b.create(loc, payloadArgs[0]); + Value realVal = complex::ReOp::create(b, loc, payloadArgs[0]); return realVal; } @@ -1098,7 +1099,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp( .getElementType(); Value lhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype); Value rhs = convertScalarToDtype(b, loc, payloadArgs[2], dtype); - return b.create(loc, payloadArgs[0], lhs, rhs); + return arith::SelectOp::create(b, loc, payloadArgs[0], lhs, rhs); } if (auto lerp = dyn_cast(op)) { @@ -1111,9 +1112,9 @@ static Value createLinalgPayloadCalculationForElementwiseOp( auto start = adaptor.getSelf(); auto end = adaptor.getEnd(); auto weight = adaptor.getWeight(); - auto delta = b.create(loc, end, start); - auto weightedDelta = b.create(loc, delta, weight); - return b.create(loc, start, weightedDelta); + auto delta = arith::SubFOp::create(b, loc, end, start); + auto weightedDelta = arith::MulFOp::create(b, loc, delta, weight); + return arith::AddFOp::create(b, loc, start, weightedDelta); } if (auto minimum = dyn_cast(op)) { Type dtype = cast(minimum.getType()).getDtype(); @@ -1123,7 +1124,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp( Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], elemTy); Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], elemTy); Value pred = createLessThan(b, loc, dtype, lhs, rhs); - return b.create(loc, pred, lhs, rhs); + return arith::SelectOp::create(b, loc, pred, lhs, rhs); } if (auto maximum = dyn_cast(op)) { Type dtype = 
cast(maximum.getType()).getDtype(); @@ -1133,7 +1134,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp( Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], elemTy); Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], elemTy); Value pred = createGreaterThan(b, loc, dtype, lhs, rhs); - return b.create(loc, pred, lhs, rhs); + return arith::SelectOp::create(b, loc, pred, lhs, rhs); } if (auto clamp = dyn_cast(op)) { AtenClampOp::Adaptor adaptor(operands); @@ -1166,15 +1167,15 @@ static Value createLinalgPayloadCalculationForElementwiseOp( if (isa(dtype)) { auto cmp = getMax ? arith::CmpFPredicate::UGT : arith::CmpFPredicate::ULT; - pred = b.create(loc, cmp, input, clamp); + pred = arith::CmpFOp::create(b, loc, cmp, input, clamp); } else if (isa(dtype)) { auto cmp = isUnsigned ? arith::CmpIPredicate::ult : arith::CmpIPredicate::slt; if (getMax) cmp = arith::invertPredicate(cmp); - pred = b.create(loc, cmp, input, clamp); + pred = arith::CmpIOp::create(b, loc, cmp, input, clamp); } - return b.create(loc, pred, clamp, input); + return arith::SelectOp::create(b, loc, pred, clamp, input); }; auto result = payloadArgs[0]; @@ -1203,36 +1204,36 @@ static Value createLinalgPayloadCalculationForElementwiseOp( auto minPromoted = convertScalarToDtype(b, loc, payloadArgs[1], dtype); Value pred; if (isa(dtype)) { - pred = b.create(loc, arith::CmpFPredicate::ULT, result, - minPromoted); + pred = arith::CmpFOp::create(b, loc, arith::CmpFPredicate::ULT, result, + minPromoted); } else if (isa(dtype)) { - pred = b.create(loc, arith::CmpIPredicate::slt, result, - minPromoted); + pred = arith::CmpIOp::create(b, loc, arith::CmpIPredicate::slt, result, + minPromoted); } else { clampTensor.emitError( "unimplemented: dtype other than float and integer " "types are not supported."); return nullptr; } - result = b.create(loc, pred, minPromoted, result); + result = arith::SelectOp::create(b, loc, pred, minPromoted, result); } if (!isa(max.getType())) { max = isMinNone ? 
payloadArgs[1] : payloadArgs[2]; auto maxPromoted = convertScalarToDtype(b, loc, max, dtype); Value pred; if (isa(dtype)) { - pred = b.create(loc, arith::CmpFPredicate::UGT, result, - maxPromoted); + pred = arith::CmpFOp::create(b, loc, arith::CmpFPredicate::UGT, result, + maxPromoted); } else if (isa(dtype)) { - pred = b.create(loc, arith::CmpIPredicate::sgt, result, - maxPromoted); + pred = arith::CmpIOp::create(b, loc, arith::CmpIPredicate::sgt, result, + maxPromoted); } else { clampTensor.emitError( "unimplemented: dtype other than float and integer " "types are not supported."); return nullptr; } - result = b.create(loc, pred, maxPromoted, result); + result = arith::SelectOp::create(b, loc, pred, maxPromoted, result); } return result; } @@ -1245,11 +1246,11 @@ static Value createLinalgPayloadCalculationForElementwiseOp( b, loc, operands[2], dtype, /*srcOriginalDtype=*/operands[2].getType(), /*dstOriginalDtype=*/dtype); if (isa(dtype)) { - Value mult = b.create(loc, self, alpha); - return b.create(loc, other, mult); + Value mult = arith::MulFOp::create(b, loc, self, alpha); + return arith::SubFOp::create(b, loc, other, mult); } else if (isa(dtype)) { - Value mult = b.create(loc, self, alpha); - return b.create(loc, other, mult); + Value mult = arith::MulIOp::create(b, loc, self, alpha); + return arith::SubIOp::create(b, loc, other, mult); } rsub.emitError("unimplemented: dtype other than float and integer " "types are not supported."); @@ -1262,9 +1263,9 @@ static Value createLinalgPayloadCalculationForElementwiseOp( Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], dtype); Value rhs = convertScalarToDtype(b, loc, operands[1], dtype); if (isa(dtype)) - return b.create(loc, lhs, rhs); + return arith::MulFOp::create(b, loc, lhs, rhs); if (isa(dtype)) - return b.create(loc, lhs, rhs); + return arith::MulIOp::create(b, loc, lhs, rhs); mulScalar.emitError("unimplemented: Only integer/float dtype supported"); return nullptr; } @@ -1304,7 +1305,7 @@ static 
Value createLinalgPayloadCalculationForElementwiseOp( } Value self = convertScalarToDtype(b, loc, payloadArgs[0], dtype); Value other = convertScalarToDtype(b, loc, operands[1], dtype); - return b.create(loc, self, other); + return arith::DivFOp::create(b, loc, self, other); } if (auto remScalar = dyn_cast(op)) { return createRemainderPayload(b, loc, converter, payloadArgs, remScalar, @@ -1322,15 +1323,16 @@ static Value createLinalgPayloadCalculationForElementwiseOp( Type elementType = arg.getType(); // assert(element != 0) auto zero = - b.create(loc, FloatAttr::get(elementType, 0.0)); + arith::ConstantOp::create(b, loc, FloatAttr::get(elementType, 0.0)); auto pred = - b.create(loc, arith::CmpFPredicate::ONE, arg, zero); - b.create( - loc, pred, b.getStringAttr("unimplemented: tensor with zero element")); + arith::CmpFOp::create(b, loc, arith::CmpFPredicate::ONE, arg, zero); + cf::AssertOp::create( + b, loc, pred, + b.getStringAttr("unimplemented: tensor with zero element")); auto one = - b.create(loc, FloatAttr::get(elementType, 1.0)); - return b.create(loc, one, arg); + arith::ConstantOp::create(b, loc, FloatAttr::get(elementType, 1.0)); + return arith::DivFOp::create(b, loc, one, arg); } if (auto thresholdOp = dyn_cast(op)) { // The approach used here is as follows: @@ -1347,12 +1349,12 @@ static Value createLinalgPayloadCalculationForElementwiseOp( Value predicate; if (isa(dtype)) - predicate = b.create(loc, arith::CmpFPredicate::ULE, self, - threshold); + predicate = arith::CmpFOp::create(b, loc, arith::CmpFPredicate::ULE, self, + threshold); else - predicate = b.create(loc, arith::CmpIPredicate::sle, self, - threshold); - return b.create(loc, predicate, value, self); + predicate = arith::CmpIOp::create(b, loc, arith::CmpIPredicate::sle, self, + threshold); + return arith::SelectOp::create(b, loc, predicate, value, self); } if (auto thresholdBackward = dyn_cast(op)) { // The approach used here is as follows: @@ -1366,16 +1368,17 @@ static Value 
createLinalgPayloadCalculationForElementwiseOp( Value self = convertScalarToDtype(b, loc, payloadArgs[1], dtype); Value threshold = convertScalarToDtype(b, loc, adaptor.getThreshold(), dtype); - Value constantZero = b.create(loc, b.getZeroAttr(dtype)); + Value constantZero = + arith::ConstantOp::create(b, loc, b.getZeroAttr(dtype)); Value predicate; if (isa(dtype)) - predicate = b.create(loc, arith::CmpFPredicate::ULE, self, - threshold); + predicate = arith::CmpFOp::create(b, loc, arith::CmpFPredicate::ULE, self, + threshold); else - predicate = b.create(loc, arith::CmpIPredicate::sle, self, - threshold); - return b.create(loc, predicate, constantZero, grad); + predicate = arith::CmpIOp::create(b, loc, arith::CmpIPredicate::sle, self, + threshold); + return arith::SelectOp::create(b, loc, predicate, constantZero, grad); } if (auto fillScalar = dyn_cast(op)) { AtenFillScalarOp::Adaptor adaptor(operands); @@ -1393,7 +1396,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp( Value input = payloadArgs[0]; Value mask = payloadArgs[1]; Value fillValue = convertScalarToDtype(b, loc, payloadArgs[2], dtype); - return b.create(loc, mask, fillValue, input); + return arith::SelectOp::create(b, loc, mask, fillValue, input); } if (auto fillTensor = dyn_cast(op)) { AtenFillTensorOp::Adaptor adaptor(operands); @@ -1428,11 +1431,12 @@ static Value createLinalgPayloadCalculationForElementwiseOp( return nullptr; } - Value allOnesVal = b.create( - loc, b.getIntegerAttr( - elementType, - APSInt::getAllOnes(elementType.getIntOrFloatBitWidth()))); - return b.create(loc, payloadArgs[0], allOnesVal); + Value allOnesVal = arith::ConstantOp::create( + b, loc, + b.getIntegerAttr( + elementType, + APSInt::getAllOnes(elementType.getIntOrFloatBitWidth()))); + return arith::XOrIOp::create(b, loc, payloadArgs[0], allOnesVal); } if (isa(op)) { @@ -1463,9 +1467,9 @@ static Value createLinalgPayloadCalculationForElementwiseOp( if (valueTy != outIntTy) { if 
(torch_to_linalg::isUnsignedTorchType(qtensorTy)) { - value = b.create(loc, outIntTy, value); + value = arith::ExtUIOp::create(b, loc, outIntTy, value); } else { - value = b.create(loc, outIntTy, value); + value = arith::ExtSIOp::create(b, loc, outIntTy, value); } } @@ -1474,20 +1478,20 @@ static Value createLinalgPayloadCalculationForElementwiseOp( auto zpTy = zp.getType(); if (zpTy != outIntTy) { - zp = b.create(loc, outIntTy, zp); + zp = arith::TruncIOp::create(b, loc, outIntTy, zp); } - value = b.create(loc, value, zp); + value = arith::SubIOp::create(b, loc, value, zp); // treat the i32 as a signed int regardless of original signed-ness // this will prevent overflow from subtraction for unsigned quantizations. - value = b.create(loc, outFpTy, value); + value = arith::SIToFPOp::create(b, loc, outFpTy, value); scale = converter->materializeTargetConversion( b, loc, converter->convertType(scale.getType()), scale); if (scale.getType() != value.getType()) { - scale = b.create(loc, value.getType(), scale); + scale = arith::TruncFOp::create(b, loc, value.getType(), scale); } - value = b.create(loc, value, scale); + value = arith::MulFOp::create(b, loc, value, scale); return value; } @@ -1499,15 +1503,15 @@ static Value createLinalgPayloadCalculationForElementwiseOp( zp = converter->materializeTargetConversion( b, loc, converter->convertType(zp.getType()), zp); - zp = b.create(loc, valueTy, zp); + zp = arith::SIToFPOp::create(b, loc, valueTy, zp); scale = converter->materializeTargetConversion( b, loc, converter->convertType(scale.getType()), scale); - scale = b.create(loc, valueTy, scale); + scale = arith::TruncFOp::create(b, loc, valueTy, scale); - value = b.create(loc, value, scale); - value = b.create(loc, value); - value = b.create(loc, value, zp); + value = arith::DivFOp::create(b, loc, value, scale); + value = math::RoundEvenOp::create(b, loc, value); + value = arith::AddFOp::create(b, loc, value, zp); auto destTy = payloadArgs[1].getType(); auto bitwidth = 
destTy.getIntOrFloatBitWidth(); @@ -1522,16 +1526,16 @@ static Value createLinalgPayloadCalculationForElementwiseOp( double maxI = isUnsigned ? static_cast(max.getZExtValue()) : static_cast(max.getSExtValue()); Value minVal = - b.create(loc, b.getFloatAttr(valueTy, minI)); + arith::ConstantOp::create(b, loc, b.getFloatAttr(valueTy, minI)); Value maxVal = - b.create(loc, b.getFloatAttr(valueTy, maxI)); - value = b.create(loc, value, minVal); - value = b.create(loc, value, maxVal); + arith::ConstantOp::create(b, loc, b.getFloatAttr(valueTy, maxI)); + value = arith::MaximumFOp::create(b, loc, value, minVal); + value = arith::MinimumFOp::create(b, loc, value, maxVal); if (isUnsigned) { - value = b.create(loc, destTy, value); + value = arith::FPToUIOp::create(b, loc, destTy, value); } else { - value = b.create(loc, destTy, value); + value = arith::FPToSIOp::create(b, loc, destTy, value); } return value; @@ -1566,17 +1570,17 @@ static Value createLinalgPayloadCalculationForElementwiseOp( auto cvtArg1 = convertScalarToDtype(b, loc, payloadArgs[1], computeType); // Reference to the definition of torch.isclose: // ∣input − other∣ <= atol + rtol × ∣other∣ - auto diff = b.create(loc, computeType, cvtArg0, cvtArg1); - auto absDiff = b.create(loc, computeType, diff); + auto diff = arith::SubFOp::create(b, loc, computeType, cvtArg0, cvtArg1); + auto absDiff = math::AbsFOp::create(b, loc, computeType, diff); auto cstRtol = - b.create(loc, b.getFloatAttr(computeType, rtol)); - auto absOther = b.create(loc, computeType, cvtArg1); - auto mul = b.create(loc, computeType, cstRtol, absOther); + arith::ConstantOp::create(b, loc, b.getFloatAttr(computeType, rtol)); + auto absOther = math::AbsFOp::create(b, loc, computeType, cvtArg1); + auto mul = arith::MulFOp::create(b, loc, computeType, cstRtol, absOther); auto cstAtol = - b.create(loc, b.getFloatAttr(computeType, atol)); - auto threshold = b.create(loc, computeType, cstAtol, mul); - return b.create(loc, arith::CmpFPredicate::ULE, 
absDiff, - threshold); + arith::ConstantOp::create(b, loc, b.getFloatAttr(computeType, atol)); + auto threshold = arith::AddFOp::create(b, loc, computeType, cstAtol, mul); + return arith::CmpFOp::create(b, loc, arith::CmpFPredicate::ULE, absDiff, + threshold); } op->emitError("unimplemented lowering in " @@ -1659,7 +1663,7 @@ class ConvertElementwiseOp : public ConversionPattern { hadErrorCreatingPayload = true; return; } - b.create(loc, result); + linalg::YieldOp::create(b, loc, result); }); if (hadErrorCreatingPayload) return failure(); @@ -1714,36 +1718,37 @@ class ConvertAtenNllLossForwardOp getTypeConverter()->convertType(op->getResult(0).getType())); Type elementType = resultType.getElementType(); - Value zeroVal = rewriter.create( - loc, rewriter.getZeroAttr(elementType)); + Value zeroVal = arith::ConstantOp::create( + rewriter, loc, rewriter.getZeroAttr(elementType)); Value finalRes = torch_to_linalg::createElementwiseLinalgGeneric( rewriter, loc, {target}, elementType, [&](OpBuilder &b, Location loc, ValueRange args) { Value targetVal = args[0]; - Value indTarget = rewriter.create( - loc, rewriter.getIndexType(), targetVal); + Value indTarget = arith::IndexCastOp::create( + rewriter, loc, rewriter.getIndexType(), targetVal); // The final result is given by: // final_res = (indTarget == ignoreIndexVal) ? 
0 : // input[indI][IndTarget] - Value cmpEq = rewriter.create( - loc, arith::CmpIPredicate::eq, indTarget, ignoreIndexVal); + Value cmpEq = + arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::eq, + indTarget, ignoreIndexVal); SmallVector extractionIndices{indTarget}; if (inputRank == 2) { - Value indI = rewriter.create(loc, 0); + Value indI = linalg::IndexOp::create(rewriter, loc, 0); extractionIndices.insert(extractionIndices.begin(), indI); } - Value result = - rewriter.create(loc, input, extractionIndices); + Value result = tensor::ExtractOp::create(rewriter, loc, input, + extractionIndices); Value negate = - rewriter.create(loc, elementType, result); + arith::NegFOp::create(rewriter, loc, elementType, result); Value selectFinal = - rewriter.create(loc, cmpEq, zeroVal, negate); - b.create(loc, selectFinal); + arith::SelectOp::create(rewriter, loc, cmpEq, zeroVal, negate); + linalg::YieldOp::create(b, loc, selectFinal); }); llvm::iota_range dimsToReduce(0, targetRank, @@ -1753,26 +1758,27 @@ class ConvertAtenNllLossForwardOp if (reduction == torch_upstream::Reduction::Sum || reduction == torch_upstream::Reduction::Mean) { - Value zeroIVal = rewriter.create( - loc, rewriter.getZeroAttr(rewriter.getI32Type())); + Value zeroIVal = arith::ConstantOp::create( + rewriter, loc, rewriter.getZeroAttr(rewriter.getI32Type())); auto countInfo = torch_to_linalg::ReductionOpInfo{false, target, dimSet}; Value numOfElems = torch_to_linalg::createReductionLinalgGeneric( rewriter, loc, countInfo, /*initElem=*/zeroIVal, [&](OpBuilder &b, Location loc, ValueRange args) { Value targetVal = args[0]; - Value indTarget = rewriter.create( - loc, rewriter.getIndexType(), targetVal); - Value cmpEq = rewriter.create( - loc, arith::CmpIPredicate::ne, indTarget, ignoreIndexVal); - cmpEq = rewriter.create(loc, rewriter.getI32Type(), - cmpEq); - Value add = rewriter.create(loc, args[1], cmpEq); - rewriter.create(loc, add); + Value indTarget = arith::IndexCastOp::create( + rewriter, 
loc, rewriter.getIndexType(), targetVal); + Value cmpEq = + arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::ne, + indTarget, ignoreIndexVal); + cmpEq = arith::ExtUIOp::create(rewriter, loc, rewriter.getI32Type(), + cmpEq); + Value add = arith::AddIOp::create(rewriter, loc, args[1], cmpEq); + linalg::YieldOp::create(rewriter, loc, add); }); - numOfElems = rewriter.create( - loc, rewriter.getI32Type(), numOfElems, ArrayRef{}); + numOfElems = tensor::ExtractOp::create( + rewriter, loc, rewriter.getI32Type(), numOfElems, ArrayRef{}); numOfElems = convertScalarToDtype(rewriter, loc, numOfElems, elementType); auto opInfo = torch_to_linalg::ReductionOpInfo{false, finalRes, dimSet}; @@ -1783,9 +1789,9 @@ class ConvertAtenNllLossForwardOp Value newVal = args[0]; Value accumulator = args[1]; if (reduction == torch_upstream::Reduction::Mean) - newVal = b.create(loc, newVal, numOfElems); - Value result = b.create(loc, newVal, accumulator); - b.create(loc, result); + newVal = arith::DivFOp::create(b, loc, newVal, numOfElems); + Value result = arith::AddFOp::create(b, loc, newVal, accumulator); + linalg::YieldOp::create(b, loc, result); }); } @@ -1805,14 +1811,14 @@ class ConvertAtenNllLossForwardOp Value numIgnoredIndex; if (targetRank == 0) { - Value targetVal = rewriter.create(loc, target); - numIgnoredIndex = rewriter.create( - loc, arith::CmpIPredicate::eq, targetVal, ignoreIndex); + Value targetVal = tensor::ExtractOp::create(rewriter, loc, target); + numIgnoredIndex = arith::CmpIOp::create( + rewriter, loc, arith::CmpIPredicate::eq, targetVal, ignoreIndex); numIgnoredIndex = convertScalarToDtype(rewriter, loc, numIgnoredIndex, ignoreIndex.getType()); } else { - Value zeroCstInt = rewriter.create( - loc, rewriter.getZeroAttr(ignoreIndex.getType())); + Value zeroCstInt = arith::ConstantOp::create( + rewriter, loc, rewriter.getZeroAttr(ignoreIndex.getType())); auto opInfo = torch_to_linalg::ReductionOpInfo{/*keepDim=*/false, target, dimSet}; @@ -1822,23 +1828,23 
@@ class ConvertAtenNllLossForwardOp [&](OpBuilder &b, Location loc, ValueRange args) { Value targetVal = args[0]; Value accumulator = args[1]; - Value result = b.create( - loc, arith::CmpIPredicate::eq, targetVal, ignoreIndex); - result = b.create( - loc, + Value result = arith::CmpIOp::create( + b, loc, arith::CmpIPredicate::eq, targetVal, ignoreIndex); + result = arith::AddIOp::create( + b, loc, convertScalarToDtype(rewriter, loc, result, ignoreIndex.getType()), accumulator); - b.create(loc, result); + linalg::YieldOp::create(b, loc, result); }); numIgnoredIndex = - rewriter.create(loc, numIgnoredIndex); + tensor::ExtractOp::create(rewriter, loc, numIgnoredIndex); } Value numtargetElems = getTensorSize(rewriter, loc, target); Value totalWeightVal = - rewriter.create(loc, numtargetElems, numIgnoredIndex); + arith::SubIOp::create(rewriter, loc, numtargetElems, numIgnoredIndex); Value totalWeight = createInitTensor( rewriter, loc, {}, elementType, convertScalarToDtype(rewriter, loc, totalWeightVal, elementType)); @@ -1853,9 +1859,9 @@ class ConvertAtenNllLossForwardOp static Value calculateRSTD(OpBuilder &b, Location loc, Type elemTy, Value eps, Value var) { // The eps is always f64. 
- Value truncatedEps = b.create(loc, elemTy, eps); - Value varPlusEps = b.create(loc, var, truncatedEps); - Value rSTD = b.create(loc, varPlusEps); + Value truncatedEps = arith::TruncFOp::create(b, loc, elemTy, eps); + Value varPlusEps = arith::AddFOp::create(b, loc, var, truncatedEps); + Value rSTD = math::RsqrtOp::create(b, loc, varPlusEps); return rSTD; } @@ -1864,10 +1870,10 @@ static Value calculateRSTD(OpBuilder &b, Location loc, Type elemTy, Value eps, static Value createLinalgPayloadCalculationForNormOpsWithRSTD( OpBuilder &b, Location loc, Type elemTy, Value input, Value mean, Value rSTD, Value eps, Value weight, Value bias) { - Value inputSubMean = b.create(loc, input, mean); - Value temp = b.create(loc, inputSubMean, rSTD); - Value timesWeight = b.create(loc, temp, weight); - Value plusBias = b.create(loc, timesWeight, bias); + Value inputSubMean = arith::SubFOp::create(b, loc, input, mean); + Value temp = arith::MulFOp::create(b, loc, inputSubMean, rSTD); + Value timesWeight = arith::MulFOp::create(b, loc, temp, weight); + Value plusBias = arith::AddFOp::create(b, loc, timesWeight, bias); return plusBias; } @@ -1926,22 +1932,22 @@ class ConvertAtenBatchNormOp : public OpConversionPattern { } // TODO: Add support for training. - auto constFalse = rewriter.create( - loc, IntegerAttr::get(IntegerType::get(context, 1), 0)); - auto trainingFalse = rewriter.create( - loc, arith::CmpIPredicate::eq, training, constFalse); - rewriter.create( - loc, trainingFalse, + auto constFalse = arith::ConstantOp::create( + rewriter, loc, IntegerAttr::get(IntegerType::get(context, 1), 0)); + auto trainingFalse = arith::CmpIOp::create( + rewriter, loc, arith::CmpIPredicate::eq, training, constFalse); + cf::AssertOp::create( + rewriter, loc, trainingFalse, rewriter.getStringAttr("training is not supported for now")); // num_features – C from an expected input of size (N,C,D,H,W ...) 
- Value numFeatures = rewriter.create(loc, input, 1); + Value numFeatures = tensor::DimOp::create(rewriter, loc, input, 1); auto contractingDim0EqualsNumFeatures = [&](Value v) { - auto dim0 = rewriter.create(loc, v, 0); - auto dim0Equal = rewriter.create( - loc, arith::CmpIPredicate::eq, numFeatures, dim0); - rewriter.create( - loc, dim0Equal, + auto dim0 = tensor::DimOp::create(rewriter, loc, v, 0); + auto dim0Equal = arith::CmpIOp::create( + rewriter, loc, arith::CmpIPredicate::eq, numFeatures, dim0); + cf::AssertOp::create( + rewriter, loc, dim0Equal, rewriter.getStringAttr( "expect the size of dim 0 equal to the number of features")); }; @@ -1966,21 +1972,18 @@ class ConvertAtenBatchNormOp : public OpConversionPattern { SmallVector iteratorTypes( inputRank, utils::IteratorType::parallel); Value batchNorm = - rewriter - .create( - loc, input.getType(), - ValueRange{input, weight, bias, runningMean, runningVar}, input, - /*indexingMaps=*/indexingMaps, - /*iteratorTypes=*/iteratorTypes, - [&](OpBuilder &b, Location loc, ValueRange args) { - Value input = args[0], weight = args[1], bias = args[2], - mean = args[3], var = args[4]; - Value result = - createLinalgPayloadCalculationForNormOpsWithVar( - b, loc, var.getType(), input, mean, var, eps, weight, - bias); - b.create(loc, result); - }) + linalg::GenericOp::create( + rewriter, loc, input.getType(), + ValueRange{input, weight, bias, runningMean, runningVar}, input, + /*indexingMaps=*/indexingMaps, + /*iteratorTypes=*/iteratorTypes, + [&](OpBuilder &b, Location loc, ValueRange args) { + Value input = args[0], weight = args[1], bias = args[2], + mean = args[3], var = args[4]; + Value result = createLinalgPayloadCalculationForNormOpsWithVar( + b, loc, var.getType(), input, mean, var, eps, weight, bias); + linalg::YieldOp::create(b, loc, result); + }) .getResult(0); Type newResultType = getTypeConverter()->convertType(op.getType()); rewriter.replaceOpWithNewOp(op, newResultType, batchNorm); @@ -2085,45 +2088,42 @@ 
class ConvertAtenNllLossBackwardOp // NOTE: In the case of not batch dimension, `batch_index` essentially // becomes zero. Value gradInput = - rewriter - .create( - loc, gradInputTensor.getType(), - ValueRange{gradOutput, target, totalWeight}, gradInputTensor, - indexingMaps, iteratorTypes, - [&](OpBuilder &b, Location loc, ValueRange args) { - Value gradOutElem = args[0]; - Value targetElem = castIntToIndex(b, loc, args[1]); - Value totalWeightElem = args[2]; - Value classIndex = - b.create(loc, inputRank - 1); - - if (reduction == torch_upstream::Reduction::Mean) { - gradOutElem = b.create(loc, gradOutElem, - totalWeightElem); - } - - Value negGradOutElem = - b.create(loc, gradOutElem); - Value weightElem = getConstant(b, loc, 1, resultElementType); - if (!weightIsNone) { - weightElem = - b.create(loc, weight, targetElem); - } - Value weightedNegGradOutElem = - b.create(loc, weightElem, negGradOutElem); - - Value targetNeqClassIndex = b.create( - loc, arith::CmpIPredicate::ne, targetElem, classIndex); - Value targetEqIgnoreIndex = b.create( - loc, arith::CmpIPredicate::eq, targetElem, ignoreIndex); - Value gradInputIsZero = b.create( - loc, targetNeqClassIndex, targetEqIgnoreIndex); - - Value zero = getConstant(b, loc, 0, resultElementType); - Value gradInElem = b.create( - loc, gradInputIsZero, zero, weightedNegGradOutElem); - b.create(loc, gradInElem); - }) + linalg::GenericOp::create( + rewriter, loc, gradInputTensor.getType(), + ValueRange{gradOutput, target, totalWeight}, gradInputTensor, + indexingMaps, iteratorTypes, + [&](OpBuilder &b, Location loc, ValueRange args) { + Value gradOutElem = args[0]; + Value targetElem = castIntToIndex(b, loc, args[1]); + Value totalWeightElem = args[2]; + Value classIndex = linalg::IndexOp::create(b, loc, inputRank - 1); + + if (reduction == torch_upstream::Reduction::Mean) { + gradOutElem = + arith::DivFOp::create(b, loc, gradOutElem, totalWeightElem); + } + + Value negGradOutElem = arith::NegFOp::create(b, loc, 
gradOutElem); + Value weightElem = getConstant(b, loc, 1, resultElementType); + if (!weightIsNone) { + weightElem = + tensor::ExtractOp::create(b, loc, weight, targetElem); + } + Value weightedNegGradOutElem = + arith::MulFOp::create(b, loc, weightElem, negGradOutElem); + + Value targetNeqClassIndex = arith::CmpIOp::create( + b, loc, arith::CmpIPredicate::ne, targetElem, classIndex); + Value targetEqIgnoreIndex = arith::CmpIOp::create( + b, loc, arith::CmpIPredicate::eq, targetElem, ignoreIndex); + Value gradInputIsZero = arith::OrIOp::create( + b, loc, targetNeqClassIndex, targetEqIgnoreIndex); + + Value zero = getConstant(b, loc, 0, resultElementType); + Value gradInElem = arith::SelectOp::create( + b, loc, gradInputIsZero, zero, weightedNegGradOutElem); + linalg::YieldOp::create(b, loc, gradInElem); + }) ->getResult(0); RankedTensorType resultType = cast( @@ -2320,43 +2320,41 @@ class ConvertLogitOp : public OpConversionPattern { SmallVector iteratorTypes( inputRank, utils::IteratorType::parallel); Value logit = - rewriter - .create( - loc, input.getType(), - /*ins=*/input, - /*outs=*/input, - /*indexingMaps=*/indexingMaps, - /*iteratorTypes=*/iteratorTypes, - [&](OpBuilder &b, Location loc, ValueRange args) { - Value input = args[0]; - - TypedAttr oneAttr = b.getFloatAttr(inputElementType, 1.0); - Value oneValue = b.create(loc, oneAttr); - - Value zI; - if (!handleEps) { - zI = input; - } else { - Value truncEps = - b.create(loc, inputElementType, eps); - Value oneMinusEps = - b.create(loc, oneValue, truncEps); - - Value min = - b.create(loc, input, oneMinusEps); - Value clampedInput = - b.create(loc, min, truncEps); - - zI = clampedInput; - } - - Value probability = - b.create(loc, oneValue, zI); - Value odds = b.create(loc, zI, probability); - Value result = b.create(loc, odds); - - b.create(loc, result); - }) + linalg::GenericOp::create( + rewriter, loc, input.getType(), + /*ins=*/input, + /*outs=*/input, + /*indexingMaps=*/indexingMaps, + 
/*iteratorTypes=*/iteratorTypes, + [&](OpBuilder &b, Location loc, ValueRange args) { + Value input = args[0]; + + TypedAttr oneAttr = b.getFloatAttr(inputElementType, 1.0); + Value oneValue = arith::ConstantOp::create(b, loc, oneAttr); + + Value zI; + if (!handleEps) { + zI = input; + } else { + Value truncEps = + arith::TruncFOp::create(b, loc, inputElementType, eps); + Value oneMinusEps = + arith::SubFOp::create(b, loc, oneValue, truncEps); + + Value min = + arith::MinimumFOp::create(b, loc, input, oneMinusEps); + Value clampedInput = + arith::MaximumFOp::create(b, loc, min, truncEps); + + zI = clampedInput; + } + + Value probability = arith::SubFOp::create(b, loc, oneValue, zI); + Value odds = arith::DivFOp::create(b, loc, zI, probability); + Value result = math::LogOp::create(b, loc, odds); + + linalg::YieldOp::create(b, loc, result); + }) .getResult(0); Type newResultType = getTypeConverter()->convertType(op.getType()); rewriter.replaceOpWithNewOp(op, newResultType, logit); @@ -2422,7 +2420,8 @@ class ConvertDequantizePerChannel llvm::SmallVector dynSizes; for (auto [index, dim] : llvm::enumerate(resultType.getShape())) { if (ShapedType::isDynamic(dim)) { - dynSizes.push_back(rewriter.create(loc, operand, index)); + dynSizes.push_back( + tensor::DimOp::create(rewriter, loc, operand, index)); } } @@ -2437,37 +2436,37 @@ class ConvertDequantizePerChannel maps[2] = broadcastMap; auto empty = - rewriter.create(op.getLoc(), resultType, dynSizes); - auto linalgOp = rewriter.create( - loc, resultType, ValueRange{operand, scale, zeropoint}, + tensor::EmptyOp::create(rewriter, op.getLoc(), resultType, dynSizes); + auto linalgOp = linalg::GenericOp::create( + rewriter, loc, resultType, ValueRange{operand, scale, zeropoint}, ValueRange{empty}, maps, iterators, [&](OpBuilder &b, Location loc, ValueRange args) { Value operand = args[0]; Value scale = args[1]; Value zeropoint = args[2]; if (operandDTy.isUnsignedInteger(8)) { - operand = b.create(loc, b.getI32Type(), 
operand); + operand = arith::ExtUIOp::create(b, loc, b.getI32Type(), operand); } else if (operandDTy.isSignedInteger(8)) { - operand = b.create(loc, b.getI32Type(), operand); + operand = arith::ExtSIOp::create(b, loc, b.getI32Type(), operand); } if (zeropointDTy.isUnsignedInteger(8)) { zeropoint = - b.create(loc, b.getI32Type(), zeropoint); + arith::ExtUIOp::create(b, loc, b.getI32Type(), zeropoint); } else if (zeropointDTy.isSignedInteger(8)) { zeropoint = - b.create(loc, b.getI32Type(), zeropoint); + arith::ExtSIOp::create(b, loc, b.getI32Type(), zeropoint); } else if (zeropointDTy.isInteger(64)) { zeropoint = - b.create(loc, b.getI32Type(), zeropoint); + arith::TruncIOp::create(b, loc, b.getI32Type(), zeropoint); op->emitWarning() << "truncated zero point from 64 to 32 bit"; } - Value sub = rewriter.create(loc, operand, zeropoint); + Value sub = arith::SubIOp::create(rewriter, loc, operand, zeropoint); Value fp = - rewriter.create(loc, args[3].getType(), sub); - Value mul = rewriter.create(loc, fp, scale); - b.create(loc, mul); + arith::SIToFPOp::create(rewriter, loc, args[3].getType(), sub); + Value mul = arith::MulFOp::create(rewriter, loc, fp, scale); + linalg::YieldOp::create(b, loc, mul); }); rewriter.replaceOp(op, linalgOp.getResults()); return success(); @@ -2505,33 +2504,33 @@ class ConvertAtenGridSamplerOp : public OpConversionPattern { Location loc = op->getLoc(); Type int64type = rewriter.getI64Type(); Type floatType = rewriter.getF32Type(); - Value oneIndex = rewriter.create(loc, 1); - Value zeroFloat = rewriter.create( - loc, rewriter.getFloatAttr(floatType, 0.0)); - Value oneFloat = rewriter.create( - loc, rewriter.getFloatAttr(floatType, 1.0)); - Value twoFloat = rewriter.create( - loc, rewriter.getFloatAttr(floatType, 2.0)); + Value oneIndex = arith::ConstantIndexOp::create(rewriter, loc, 1); + Value zeroFloat = arith::ConstantOp::create( + rewriter, loc, rewriter.getFloatAttr(floatType, 0.0)); + Value oneFloat = arith::ConstantOp::create( + 
rewriter, loc, rewriter.getFloatAttr(floatType, 1.0)); + Value twoFloat = arith::ConstantOp::create( + rewriter, loc, rewriter.getFloatAttr(floatType, 2.0)); Value input = adaptor.getInput(); auto inputType = cast(input.getType()); - Value innerDim0a = rewriter.create(loc, input, 2); - Value innerDim1a = rewriter.create(loc, input, 3); + Value innerDim0a = tensor::DimOp::create(rewriter, loc, input, 2); + Value innerDim1a = tensor::DimOp::create(rewriter, loc, input, 3); Value innerDim0b = - rewriter.create(loc, innerDim0a, oneIndex); + arith::SubIOp::create(rewriter, loc, innerDim0a, oneIndex); Value innerDim1b = - rewriter.create(loc, innerDim1a, oneIndex); + arith::SubIOp::create(rewriter, loc, innerDim1a, oneIndex); Value innerDim0c = - rewriter.create(loc, int64type, innerDim0b); + arith::IndexCastOp::create(rewriter, loc, int64type, innerDim0b); Value innerDim1c = - rewriter.create(loc, int64type, innerDim1b); + arith::IndexCastOp::create(rewriter, loc, int64type, innerDim1b); Value innerDim0d = - rewriter.create(loc, floatType, innerDim0c); + arith::SIToFPOp::create(rewriter, loc, floatType, innerDim0c); Value innerDim1d = - rewriter.create(loc, floatType, innerDim1c); + arith::SIToFPOp::create(rewriter, loc, floatType, innerDim1c); Value innerDim0e = - rewriter.create(loc, innerDim0d, twoFloat); + arith::DivFOp::create(rewriter, loc, innerDim0d, twoFloat); Value innerDim1e = - rewriter.create(loc, innerDim1d, twoFloat); + arith::DivFOp::create(rewriter, loc, innerDim1d, twoFloat); Value grid = adaptor.getGrid(); auto gridType = cast(grid.getType()); auto gridRank = gridType.getRank(); @@ -2552,26 +2551,26 @@ class ConvertAtenGridSamplerOp : public OpConversionPattern { auto lambdaExtract = [](OpBuilder &b, Location loc, Value input, Value idxA, Value idxB, Value idxC, Value idxD) -> Value { SmallVector index{idxA, idxB, idxC, idxD}; - Value result = b.create(loc, input, index); + Value result = tensor::ExtractOp::create(b, loc, input, index); return result; 
}; auto lambdaLinear = [&](OpBuilder &b, Location loc, Value x, Value y, Value d) -> Value { - Value dm = b.create(loc, oneFloat, d); - Value ra = b.create(loc, x, dm); - Value rb = b.create(loc, y, d); - Value res = b.create(loc, ra, rb); + Value dm = arith::SubFOp::create(b, loc, oneFloat, d); + Value ra = arith::MulFOp::create(b, loc, x, dm); + Value rb = arith::MulFOp::create(b, loc, y, d); + Value res = arith::AddFOp::create(b, loc, ra, rb); return res; }; auto lambdaNearest = [&](OpBuilder &b, Location loc, Value x, Value y, Value d) -> Value { - Value halfConst = rewriter.create( - loc, rewriter.getFloatAttr(floatType, 0.5)); - Value checkClosest = - b.create(loc, arith::CmpFPredicate::OLT, d, halfConst); - Value res = b.create(loc, checkClosest, x, y); + Value halfConst = arith::ConstantOp::create( + rewriter, loc, rewriter.getFloatAttr(floatType, 0.5)); + Value checkClosest = arith::CmpFOp::create( + b, loc, arith::CmpFPredicate::OLT, d, halfConst); + Value res = arith::SelectOp::create(b, loc, checkClosest, x, y); return res; }; @@ -2580,10 +2579,10 @@ class ConvertAtenGridSamplerOp : public OpConversionPattern { Value linear = lambdaLinear(b, loc, x, y, d); Value nearest = lambdaNearest(b, loc, x, y, d); Value zeroInt = - b.create(loc, b.getIntegerAttr(int64type, 0)); - Value checkMode = b.create(loc, arith::CmpIPredicate::eq, - iMode, zeroInt); - Value res = b.create(loc, checkMode, linear, nearest); + arith::ConstantOp::create(b, loc, b.getIntegerAttr(int64type, 0)); + Value checkMode = arith::CmpIOp::create(b, loc, arith::CmpIPredicate::eq, + iMode, zeroInt); + Value res = arith::SelectOp::create(b, loc, checkMode, linear, nearest); return res; }; @@ -2593,108 +2592,110 @@ class ConvertAtenGridSamplerOp : public OpConversionPattern { Value interMode = adaptor.getInterpolationMode(); SmallVector dynamicSizes{}; if (resultType.isDynamicDim(0)) - dynamicSizes.push_back(rewriter.create(loc, input, 0)); + 
dynamicSizes.push_back(tensor::DimOp::create(rewriter, loc, input, 0)); if (resultType.isDynamicDim(1)) - dynamicSizes.push_back(rewriter.create(loc, input, 1)); + dynamicSizes.push_back(tensor::DimOp::create(rewriter, loc, input, 1)); if (resultType.isDynamicDim(2)) - dynamicSizes.push_back(rewriter.create(loc, grid, 1)); + dynamicSizes.push_back(tensor::DimOp::create(rewriter, loc, grid, 1)); if (resultType.isDynamicDim(3)) - dynamicSizes.push_back(rewriter.create(loc, grid, 2)); + dynamicSizes.push_back(tensor::DimOp::create(rewriter, loc, grid, 2)); tensor::EmptyOp emptyOp = - rewriter.create(loc, resultType, dynamicSizes); - auto sGrid = rewriter.create( - loc, TypeRange{resultType}, ValueRange{grid, grid}, ValueRange(emptyOp), - gridMaps, gridIterators, + tensor::EmptyOp::create(rewriter, loc, resultType, dynamicSizes); + auto sGrid = linalg::GenericOp::create( + rewriter, loc, TypeRange{resultType}, ValueRange{grid, grid}, + ValueRange(emptyOp), gridMaps, gridIterators, [&](OpBuilder &b, Location loc, ValueRange args) { Value gr0 = args[1]; Value gr1 = args[0]; - Value gr0Half = b.create(loc, gr0, twoFloat); - Value gr1Half = b.create(loc, gr1, twoFloat); + Value gr0Half = arith::DivFOp::create(b, loc, gr0, twoFloat); + Value gr1Half = arith::DivFOp::create(b, loc, gr1, twoFloat); Value gr0HalfSelect = - b.create(loc, alignCorners, zeroFloat, gr0Half); + arith::SelectOp::create(b, loc, alignCorners, zeroFloat, gr0Half); Value gr1HalfSelect = - b.create(loc, alignCorners, zeroFloat, gr1Half); - Value gplus0 = b.create(loc, gr0, oneFloat); - Value gplus1 = b.create(loc, gr1, oneFloat); - Value gPlusMul0 = b.create(loc, gplus0, innerDim0e); - Value gPlusMul1 = b.create(loc, gplus1, innerDim1e); + arith::SelectOp::create(b, loc, alignCorners, zeroFloat, gr1Half); + Value gplus0 = arith::AddFOp::create(b, loc, gr0, oneFloat); + Value gplus1 = arith::AddFOp::create(b, loc, gr1, oneFloat); + Value gPlusMul0 = arith::MulFOp::create(b, loc, gplus0, innerDim0e); + 
Value gPlusMul1 = arith::MulFOp::create(b, loc, gplus1, innerDim1e); Value result0 = - b.create(loc, gPlusMul0, gr0HalfSelect); + arith::AddFOp::create(b, loc, gPlusMul0, gr0HalfSelect); Value result1 = - b.create(loc, gPlusMul1, gr1HalfSelect); - Value checkLowerBound0 = b.create( - loc, arith::CmpFPredicate::OLT, result0, zeroFloat); - Value checkLowerBound1 = b.create( - loc, arith::CmpFPredicate::OLT, result1, zeroFloat); - Value lowerOrig0 = b.create(loc, int64type, result0); - Value lowerOrig1 = b.create(loc, int64type, result1); + arith::AddFOp::create(b, loc, gPlusMul1, gr1HalfSelect); + Value checkLowerBound0 = arith::CmpFOp::create( + b, loc, arith::CmpFPredicate::OLT, result0, zeroFloat); + Value checkLowerBound1 = arith::CmpFOp::create( + b, loc, arith::CmpFPredicate::OLT, result1, zeroFloat); + Value lowerOrig0 = + arith::FPToSIOp::create(b, loc, int64type, result0); + Value lowerOrig1 = + arith::FPToSIOp::create(b, loc, int64type, result1); Value zeroInt = - b.create(loc, b.getIntegerAttr(int64type, 0)); + arith::ConstantOp::create(b, loc, b.getIntegerAttr(int64type, 0)); Value oneInt = - b.create(loc, b.getIntegerAttr(int64type, 1)); - Value lowerSub0 = b.create(loc, lowerOrig0, oneInt); - Value lowerSub1 = b.create(loc, lowerOrig1, oneInt); - Value lower0 = b.create(loc, checkLowerBound0, - lowerSub0, lowerOrig0); - Value lower1 = b.create(loc, checkLowerBound1, - lowerSub1, lowerOrig1); - Value lowerValid0 = - b.create(loc, checkLowerBound0, zeroInt, lower0); - Value lowerValid1 = - b.create(loc, checkLowerBound1, zeroInt, lower1); + arith::ConstantOp::create(b, loc, b.getIntegerAttr(int64type, 1)); + Value lowerSub0 = arith::SubIOp::create(b, loc, lowerOrig0, oneInt); + Value lowerSub1 = arith::SubIOp::create(b, loc, lowerOrig1, oneInt); + Value lower0 = arith::SelectOp::create(b, loc, checkLowerBound0, + lowerSub0, lowerOrig0); + Value lower1 = arith::SelectOp::create(b, loc, checkLowerBound1, + lowerSub1, lowerOrig1); + Value lowerValid0 = 
arith::SelectOp::create(b, loc, checkLowerBound0, + zeroInt, lower0); + Value lowerValid1 = arith::SelectOp::create(b, loc, checkLowerBound1, + zeroInt, lower1); Value upper0 = - b.create(loc, int64type, lower0, oneInt); + arith::AddIOp::create(b, loc, int64type, lower0, oneInt); Value upper1 = - b.create(loc, int64type, lower1, oneInt); - Value notValidUpper0 = rewriter.create( - loc, arith::CmpIPredicate::sgt, upper0, innerDim0c); - Value notValidUpper1 = rewriter.create( - loc, arith::CmpIPredicate::sgt, upper1, innerDim1c); + arith::AddIOp::create(b, loc, int64type, lower1, oneInt); + Value notValidUpper0 = arith::CmpIOp::create( + rewriter, loc, arith::CmpIPredicate::sgt, upper0, innerDim0c); + Value notValidUpper1 = arith::CmpIOp::create( + rewriter, loc, arith::CmpIPredicate::sgt, upper1, innerDim1c); Value upperValid0 = - b.create(loc, notValidUpper0, lower0, upper0); + arith::SelectOp::create(b, loc, notValidUpper0, lower0, upper0); Value upperValid1 = - b.create(loc, notValidUpper1, lower1, upper1); + arith::SelectOp::create(b, loc, notValidUpper1, lower1, upper1); Value lw0 = - b.create(loc, b.getIndexType(), lowerValid0); + arith::IndexCastOp::create(b, loc, b.getIndexType(), lowerValid0); Value lw1 = - b.create(loc, b.getIndexType(), lowerValid1); + arith::IndexCastOp::create(b, loc, b.getIndexType(), lowerValid1); Value up0 = - b.create(loc, b.getIndexType(), upperValid0); + arith::IndexCastOp::create(b, loc, b.getIndexType(), upperValid0); Value up1 = - b.create(loc, b.getIndexType(), upperValid1); - Value N = b.create(loc, 0); - Value C = b.create(loc, 1); + arith::IndexCastOp::create(b, loc, b.getIndexType(), upperValid1); + Value N = linalg::IndexOp::create(b, loc, 0); + Value C = linalg::IndexOp::create(b, loc, 1); Value result00 = lambdaExtract(b, loc, input, N, C, lw0, lw1); - Value result00a = b.create(loc, checkLowerBound0, - zeroFloat, result00); - Value result00b = b.create(loc, checkLowerBound1, - zeroFloat, result00a); + Value result00a = 
arith::SelectOp::create(b, loc, checkLowerBound0, + zeroFloat, result00); + Value result00b = arith::SelectOp::create(b, loc, checkLowerBound1, + zeroFloat, result00a); Value result01 = lambdaExtract(b, loc, input, N, C, lw0, up1); - Value result01a = b.create(loc, notValidUpper1, - zeroFloat, result01); - Value result01b = b.create(loc, checkLowerBound0, - zeroFloat, result01a); + Value result01a = arith::SelectOp::create(b, loc, notValidUpper1, + zeroFloat, result01); + Value result01b = arith::SelectOp::create(b, loc, checkLowerBound0, + zeroFloat, result01a); Value result10 = lambdaExtract(b, loc, input, N, C, up0, lw1); - Value result10a = b.create(loc, notValidUpper0, - zeroFloat, result10); - Value result10b = b.create(loc, checkLowerBound1, - zeroFloat, result10a); + Value result10a = arith::SelectOp::create(b, loc, notValidUpper0, + zeroFloat, result10); + Value result10b = arith::SelectOp::create(b, loc, checkLowerBound1, + zeroFloat, result10a); Value result11 = lambdaExtract(b, loc, input, N, C, up0, up1); - Value result11a = b.create(loc, notValidUpper0, - zeroFloat, result11); - Value result11b = b.create(loc, notValidUpper1, - zeroFloat, result11a); - Value lw0a = b.create(loc, floatType, lower0); - Value lw1a = b.create(loc, floatType, lower1); - Value d1 = b.create(loc, result0, lw0a); - Value d0 = b.create(loc, result1, lw1a); + Value result11a = arith::SelectOp::create(b, loc, notValidUpper0, + zeroFloat, result11); + Value result11b = arith::SelectOp::create(b, loc, notValidUpper1, + zeroFloat, result11a); + Value lw0a = arith::SIToFPOp::create(b, loc, floatType, lower0); + Value lw1a = arith::SIToFPOp::create(b, loc, floatType, lower1); + Value d1 = arith::SubFOp::create(b, loc, result0, lw0a); + Value d0 = arith::SubFOp::create(b, loc, result1, lw1a); Value resultScaled0 = lambdaInterpolate(b, loc, interMode, result00b, result01b, d0); Value resultScaled1 = lambdaInterpolate(b, loc, interMode, result10b, result11b, d0); Value resultScaled = 
lambdaInterpolate( b, loc, interMode, resultScaled0, resultScaled1, d1); - b.create(loc, resultScaled); + linalg::YieldOp::create(b, loc, resultScaled); }); rewriter.replaceOp(op, sGrid.getResults()); return success(); @@ -2713,36 +2714,36 @@ static Value nearestInterpolate(OpBuilder &b, Location loc, SmallVector indices; for (unsigned i = 0; i < inputRank; i++) { - indices.push_back(b.create(loc, i)); + indices.push_back(linalg::IndexOp::create(b, loc, i)); } for (unsigned i = 2; i < inputRank; i++) { Value outIndex = indices[i]; Value inputSizeFP = - b.create(loc, b.getF32Type(), inputSizes[i - 2]); + arith::SIToFPOp::create(b, loc, b.getF32Type(), inputSizes[i - 2]); Value outputSizeFP = - b.create(loc, b.getF32Type(), outputSizes[i - 2]); + arith::SIToFPOp::create(b, loc, b.getF32Type(), outputSizes[i - 2]); // scale = length_resized / length_original // x_original = x_resized / scale Value scale; if (scaleValues.empty()) - scale = b.create(loc, outputSizeFP, inputSizeFP); + scale = arith::DivFOp::create(b, loc, outputSizeFP, inputSizeFP); else scale = scaleValues[i - 2]; - Value outInt = b.create(loc, b.getI64Type(), outIndex); - Value outFP = b.create(loc, b.getF32Type(), outInt); + Value outInt = arith::IndexCastOp::create(b, loc, b.getI64Type(), outIndex); + Value outFP = arith::SIToFPOp::create(b, loc, b.getF32Type(), outInt); Value proj; if (coordStr.empty() || coordStr == "_asymmetric") { - proj = b.create(loc, outFP, scale); + proj = arith::DivFOp::create(b, loc, outFP, scale); } else if (coordStr == "_half_pixel") { - Value cstHalf = b.create(loc, b.getF32FloatAttr(0.5)); - Value add = b.create(loc, outFP, cstHalf); - Value div = b.create(loc, add, scale); - proj = b.create(loc, div, cstHalf); + Value cstHalf = arith::ConstantOp::create(b, loc, b.getF32FloatAttr(0.5)); + Value add = arith::AddFOp::create(b, loc, outFP, cstHalf); + Value div = arith::DivFOp::create(b, loc, add, scale); + proj = arith::SubFOp::create(b, loc, div, cstHalf); } else { 
llvm_unreachable("Unsupported coordination transformation mode"); } @@ -2750,43 +2751,43 @@ static Value nearestInterpolate(OpBuilder &b, Location loc, Value nearestFP; // get nearest pixel using floor if (nearestMode == "floor" || nearestMode == "") { - nearestFP = b.create(loc, proj); + nearestFP = math::FloorOp::create(b, loc, proj); } else if (nearestMode == "round_prefer_floor") { - Value cstHalf = b.create(loc, b.getF32FloatAttr(0.5)); - Value floor = b.create(loc, proj); - Value ceil = b.create(loc, proj); - Value decimal = b.create(loc, proj, floor); - Value cmp = b.create(loc, arith::CmpFPredicate::ULE, - decimal, cstHalf); - nearestFP = b.create(loc, cmp, floor, ceil); + Value cstHalf = arith::ConstantOp::create(b, loc, b.getF32FloatAttr(0.5)); + Value floor = math::FloorOp::create(b, loc, proj); + Value ceil = math::CeilOp::create(b, loc, proj); + Value decimal = arith::SubFOp::create(b, loc, proj, floor); + Value cmp = arith::CmpFOp::create(b, loc, arith::CmpFPredicate::ULE, + decimal, cstHalf); + nearestFP = arith::SelectOp::create(b, loc, cmp, floor, ceil); } else if (nearestMode == "round_prefer_ceil") { - Value cstHalf = b.create(loc, b.getF32FloatAttr(0.5)); - Value cstOne = b.create(loc, b.getF32FloatAttr(1)); - Value floor = b.create(loc, proj); - Value ceil = b.create(loc, proj); - Value decimal = b.create(loc, proj, floor); - Value cmp = b.create(loc, arith::CmpFPredicate::UGE, - decimal, cstHalf); - nearestFP = b.create(loc, cmp, ceil, floor); - Value inputSizeMOne = b.create(loc, inputSizeFP, cstOne); + Value cstHalf = arith::ConstantOp::create(b, loc, b.getF32FloatAttr(0.5)); + Value cstOne = arith::ConstantOp::create(b, loc, b.getF32FloatAttr(1)); + Value floor = math::FloorOp::create(b, loc, proj); + Value ceil = math::CeilOp::create(b, loc, proj); + Value decimal = arith::SubFOp::create(b, loc, proj, floor); + Value cmp = arith::CmpFOp::create(b, loc, arith::CmpFPredicate::UGE, + decimal, cstHalf); + nearestFP = arith::SelectOp::create(b, 
loc, cmp, ceil, floor); + Value inputSizeMOne = arith::SubFOp::create(b, loc, inputSizeFP, cstOne); // don't extract out of bounds - nearestFP = b.create(loc, nearestFP, inputSizeMOne); + nearestFP = arith::MinimumFOp::create(b, loc, nearestFP, inputSizeMOne); } else if (nearestMode == "ceil") { - Value cstOne = b.create(loc, b.getF32FloatAttr(1)); - Value inputSizeMOne = b.create(loc, inputSizeFP, cstOne); - nearestFP = b.create(loc, proj); - nearestFP = b.create(loc, nearestFP, inputSizeMOne); + Value cstOne = arith::ConstantOp::create(b, loc, b.getF32FloatAttr(1)); + Value inputSizeMOne = arith::SubFOp::create(b, loc, inputSizeFP, cstOne); + nearestFP = math::CeilOp::create(b, loc, proj); + nearestFP = arith::MinimumFOp::create(b, loc, nearestFP, inputSizeMOne); } else { llvm_unreachable("Unsupported nearest mode"); } Value nearestInt = - b.create(loc, b.getI64Type(), nearestFP); + arith::FPToSIOp::create(b, loc, b.getI64Type(), nearestFP); Value nearest = - b.create(loc, b.getIndexType(), nearestInt); + arith::IndexCastOp::create(b, loc, b.getIndexType(), nearestInt); indices[i] = nearest; } - Value retVal = b.create(loc, input, indices); + Value retVal = tensor::ExtractOp::create(b, loc, input, indices); return retVal; } @@ -2800,81 +2801,81 @@ static SmallVector coordinateTransform( auto inputType = cast(input.getType()); auto inputRank = inputType.getRank(); - Value cstOneFloat = b.create(loc, b.getF32FloatAttr(1.0)); - Value cstHalf = b.create(loc, b.getF32FloatAttr(0.5)); - Value zero = b.create(loc, b.getF32FloatAttr(0.0)); + Value cstOneFloat = arith::ConstantOp::create(b, loc, b.getF32FloatAttr(1.0)); + Value cstHalf = arith::ConstantOp::create(b, loc, b.getF32FloatAttr(0.5)); + Value zero = arith::ConstantOp::create(b, loc, b.getF32FloatAttr(0.0)); SmallVector proj; for (unsigned i = 0; i < inputRank - dimOffset; i++) { // length_original Value inputFP = - b.create(loc, b.getF32Type(), inputSizes[i]); + arith::SIToFPOp::create(b, loc, b.getF32Type(), 
inputSizes[i]); // length_resized Value outputSizeFP = - b.create(loc, b.getF32Type(), outputSizes[i]); + arith::SIToFPOp::create(b, loc, b.getF32Type(), outputSizes[i]); // scale = length_resized/length_original Value scale; if (alignCornersBool) { // x_original = x_resized * (length_original - 1) / (length_resized - 1) - Value inputSubOne = b.create(loc, inputFP, cstOneFloat); + Value inputSubOne = arith::SubFOp::create(b, loc, inputFP, cstOneFloat); Value outputSizeSubOne = - b.create(loc, outputSizeFP, cstOneFloat); - Value cmp = b.create(loc, arith::CmpFPredicate::UEQ, - outputSizeSubOne, zero); - scale = b.create(loc, inputSubOne, outputSizeSubOne); - scale = b.create(loc, cmp, zero, scale); + arith::SubFOp::create(b, loc, outputSizeFP, cstOneFloat); + Value cmp = arith::CmpFOp::create(b, loc, arith::CmpFPredicate::UEQ, + outputSizeSubOne, zero); + scale = arith::DivFOp::create(b, loc, inputSubOne, outputSizeSubOne); + scale = arith::SelectOp::create(b, loc, cmp, zero, scale); coordStr = "_align_corners"; } else if (scaleValues.empty()) - scale = b.create(loc, outputSizeFP, inputFP); + scale = arith::DivFOp::create(b, loc, outputSizeFP, inputFP); else scale = scaleValues[i]; // y_resized - Value outInt = b.create(loc, b.getI64Type(), - indices[i + dimOffset]); - Value outFP = b.create(loc, b.getF32Type(), outInt); + Value outInt = arith::IndexCastOp::create(b, loc, b.getI64Type(), + indices[i + dimOffset]); + Value outFP = arith::SIToFPOp::create(b, loc, b.getF32Type(), outInt); Value preClip; if (coordStr == "_align_corners") { - preClip = b.create(loc, outFP, scale); + preClip = arith::MulFOp::create(b, loc, outFP, scale); } if (coordStr == "_asymmetric") { - preClip = b.create(loc, outFP, scale); + preClip = arith::DivFOp::create(b, loc, outFP, scale); } if (coordStr == "_pytorch_half_pixel" || coordStr == "" || coordStr == "_half_pixel_symmetric") { // half-pixel modes // y_resized + 0.5 - Value outPlusHalf = b.create(loc, outFP, cstHalf); + Value 
outPlusHalf = arith::AddFOp::create(b, loc, outFP, cstHalf); // (y_resized + 0.5) / scale - Value outDivScale = b.create(loc, outPlusHalf, scale); + Value outDivScale = arith::DivFOp::create(b, loc, outPlusHalf, scale); // _ - 0.5 - preClip = b.create(loc, outDivScale, cstHalf); + preClip = arith::SubFOp::create(b, loc, outDivScale, cstHalf); } // for half_pixel_symmetric, need to compute offset from raw scales if (coordStr == "_half_pixel_symmetric" && !scaleValues.empty()) { - Value outputSizeFromScale = b.create(loc, inputFP, scale); + Value outputSizeFromScale = arith::MulFOp::create(b, loc, inputFP, scale); Value adjustment = - b.create(loc, outputSizeFP, outputSizeFromScale); - Value cstTwo = b.create(loc, b.getF32FloatAttr(2.0)); - Value center = b.create(loc, inputFP, cstTwo); + arith::DivFOp::create(b, loc, outputSizeFP, outputSizeFromScale); + Value cstTwo = arith::ConstantOp::create(b, loc, b.getF32FloatAttr(2.0)); + Value center = arith::DivFOp::create(b, loc, inputFP, cstTwo); Value oneMAdjustment = - b.create(loc, cstOneFloat, adjustment); - Value offset = b.create(loc, center, oneMAdjustment); - preClip = b.create(loc, offset, preClip); + arith::SubFOp::create(b, loc, cstOneFloat, adjustment); + Value offset = arith::MulFOp::create(b, loc, center, oneMAdjustment); + preClip = arith::AddFOp::create(b, loc, offset, preClip); } // for pytorch half pixel , special case for length_resized == 1: if (coordStr == "_pytorch_half_pixel") { - Value cmp = b.create(loc, arith::CmpFPredicate::UEQ, - outputSizeFP, cstOneFloat); - preClip = b.create(loc, cmp, zero, preClip); + Value cmp = arith::CmpFOp::create(b, loc, arith::CmpFPredicate::UEQ, + outputSizeFP, cstOneFloat); + preClip = arith::SelectOp::create(b, loc, cmp, zero, preClip); } if (clip) { // preClip is the fp position inside the input image to extract from. 
// clip to [0,inf) - Value max = b.create(loc, preClip, zero); - Value inputSubOne = b.create(loc, inputFP, cstOneFloat); + Value max = arith::MaximumFOp::create(b, loc, preClip, zero); + Value inputSubOne = arith::SubFOp::create(b, loc, inputFP, cstOneFloat); // clip to [0,length_original - 1]. // proj is properly within the input image. - proj.push_back(b.create(loc, max, inputSubOne)); + proj.push_back(arith::MinimumFOp::create(b, loc, max, inputSubOne)); } else { proj.push_back(preClip); } @@ -2892,14 +2893,14 @@ static Value bilinearInterpolate(OpBuilder &b, auto inputType = cast(input.getType()); auto inputRank = inputType.getRank(); - Value cstOneFloat = b.create(loc, b.getF32FloatAttr(1.0)); + Value cstOneFloat = arith::ConstantOp::create(b, loc, b.getF32FloatAttr(1.0)); bool alignCornersBool; matchPattern(op.getAlignCorners(), m_TorchConstantBool(&alignCornersBool)); SmallVector indices; for (unsigned i = 0; i < inputRank; i++) { - indices.push_back(b.create(loc, i)); + indices.push_back(linalg::IndexOp::create(b, loc, i)); } SmallVector proj, high, low, highFP, lowFP; @@ -2909,44 +2910,44 @@ static Value bilinearInterpolate(OpBuilder &b, for (unsigned i = 0; i < inputRank - dimOffset; i++) { // length_original Value inputFP = - b.create(loc, b.getF32Type(), inputSizes[i]); - Value inputSubOne = b.create(loc, inputFP, cstOneFloat); + arith::SIToFPOp::create(b, loc, b.getF32Type(), inputSizes[i]); + Value inputSubOne = arith::SubFOp::create(b, loc, inputFP, cstOneFloat); // for bilinear interpolation, we look for the nearest indices below and // above proj - lowFP.push_back(b.create(loc, proj[i])); - Value projPlusOne = b.create(loc, cstOneFloat, proj[i]); - highFP.push_back(b.create(loc, projPlusOne)); + lowFP.push_back(math::FloorOp::create(b, loc, proj[i])); + Value projPlusOne = arith::AddFOp::create(b, loc, cstOneFloat, proj[i]); + highFP.push_back(math::FloorOp::create(b, loc, projPlusOne)); - Value lowInt = b.create(loc, b.getI64Type(), lowFP[i]); - 
low.push_back(b.create(loc, b.getIndexType(), lowInt)); + Value lowInt = arith::FPToSIOp::create(b, loc, b.getI64Type(), lowFP[i]); + low.push_back(arith::IndexCastOp::create(b, loc, b.getIndexType(), lowInt)); // highFP could be out-of-bounds, so make sure to clip it down before // extracting. If highFP actually gets clipped here, then high[i] will // extract at the last pixel, but will treat it as if it were extracted from // one further position when computing the interpolation weights. Value highExtract = - b.create(loc, projPlusOne, inputSubOne); - highExtract = b.create(loc, b.getI64Type(), highExtract); + arith::MinimumFOp::create(b, loc, projPlusOne, inputSubOne); + highExtract = arith::FPToSIOp::create(b, loc, b.getI64Type(), highExtract); high.push_back( - b.create(loc, b.getIndexType(), highExtract)); + arith::IndexCastOp::create(b, loc, b.getIndexType(), highExtract)); } indices[dimOffset] = low[0]; indices[dimOffset + 1] = low[1]; - Value p00 = b.create(loc, input, indices); + Value p00 = tensor::ExtractOp::create(b, loc, input, indices); indices[dimOffset] = low[0]; indices[dimOffset + 1] = high[1]; - Value p01 = b.create(loc, input, indices); + Value p01 = tensor::ExtractOp::create(b, loc, input, indices); indices[dimOffset] = high[0]; indices[dimOffset + 1] = low[1]; - Value p10 = b.create(loc, input, indices); + Value p10 = tensor::ExtractOp::create(b, loc, input, indices); indices[dimOffset] = high[0]; indices[dimOffset + 1] = high[1]; - Value p11 = b.create(loc, input, indices); + Value p11 = tensor::ExtractOp::create(b, loc, input, indices); // Let Aij := area rect((yProj,xProj) <-> (y_i*,x_j*)), // where i* = i+1 mod 2 and x_0 = xLow, x_1 = xHigh etc. 
@@ -2955,23 +2956,23 @@ static Value bilinearInterpolate(OpBuilder &b, // Note: we do not need to divide by total rect area == 1 // lengths : Aij == dyi*dxj - Value dy0 = b.create(loc, highFP[0], proj[0]); - Value dy1 = b.create(loc, proj[0], lowFP[0]); - Value dx0 = b.create(loc, highFP[1], proj[1]); - Value dx1 = b.create(loc, proj[1], lowFP[1]); + Value dy0 = arith::SubFOp::create(b, loc, highFP[0], proj[0]); + Value dy1 = arith::SubFOp::create(b, loc, proj[0], lowFP[0]); + Value dx0 = arith::SubFOp::create(b, loc, highFP[1], proj[1]); + Value dx1 = arith::SubFOp::create(b, loc, proj[1], lowFP[1]); // left = A00*p00 + A01*p01 = dy0(dx0p00 + dx1p01) - Value dx0p00 = b.create(loc, dx0, p00); - Value dx1p01 = b.create(loc, dx1, p01); - Value sum = b.create(loc, dx0p00, dx1p01); - Value left = b.create(loc, dy0, sum); + Value dx0p00 = arith::MulFOp::create(b, loc, dx0, p00); + Value dx1p01 = arith::MulFOp::create(b, loc, dx1, p01); + Value sum = arith::AddFOp::create(b, loc, dx0p00, dx1p01); + Value left = arith::MulFOp::create(b, loc, dy0, sum); // right = A10*p10 + A11*p11 = dy1(dx0p10 + dx1p11) - Value dx0p10 = b.create(loc, dx0, p10); - Value dx1p11 = b.create(loc, dx1, p11); - sum = b.create(loc, dx0p10, dx1p11); - Value right = b.create(loc, dy1, sum); + Value dx0p10 = arith::MulFOp::create(b, loc, dx0, p10); + Value dx1p11 = arith::MulFOp::create(b, loc, dx1, p11); + sum = arith::AddFOp::create(b, loc, dx0p10, dx1p11); + Value right = arith::MulFOp::create(b, loc, dy1, sum); - return b.create(loc, left, right); + return arith::AddFOp::create(b, loc, left, right); } static Value bicubicInterpolate(OpBuilder &b, @@ -2985,58 +2986,62 @@ static Value bicubicInterpolate(OpBuilder &b, auto inputRank = inputType.getRank(); Value inputFPH = - b.create(loc, b.getF32Type(), inputSizes[0]); + arith::SIToFPOp::create(b, loc, b.getF32Type(), inputSizes[0]); Value inputFPW = - b.create(loc, b.getF32Type(), inputSizes[1]); + arith::SIToFPOp::create(b, loc, b.getF32Type(), 
inputSizes[1]); - Value a = b.create(loc, b.getF32FloatAttr(-0.75)); - Value zero = b.create(loc, b.getF32FloatAttr(0.0)); - Value cstOneFloat = b.create(loc, b.getF32FloatAttr(1.0)); - Value cstTwoFloat = b.create(loc, b.getF32FloatAttr(2.0)); + Value a = arith::ConstantOp::create(b, loc, b.getF32FloatAttr(-0.75)); + Value zero = arith::ConstantOp::create(b, loc, b.getF32FloatAttr(0.0)); + Value cstOneFloat = arith::ConstantOp::create(b, loc, b.getF32FloatAttr(1.0)); + Value cstTwoFloat = arith::ConstantOp::create(b, loc, b.getF32FloatAttr(2.0)); Value cstThreeFloat = - b.create(loc, b.getF32FloatAttr(3.0)); - Value cstFourFloat = b.create(loc, b.getF32FloatAttr(4.0)); - Value cstFiveFloat = b.create(loc, b.getF32FloatAttr(5.0)); + arith::ConstantOp::create(b, loc, b.getF32FloatAttr(3.0)); + Value cstFourFloat = + arith::ConstantOp::create(b, loc, b.getF32FloatAttr(4.0)); + Value cstFiveFloat = + arith::ConstantOp::create(b, loc, b.getF32FloatAttr(5.0)); Value cstEightFloat = - b.create(loc, b.getF32FloatAttr(8.0)); + arith::ConstantOp::create(b, loc, b.getF32FloatAttr(8.0)); // (a+2)|x|^3 - (a+3)|x|^2 + 1 for xDistance (|x| <= 1) auto WeightLessThanEqualOne = [&](Value xDistance) -> Value { - Value xDistanceSquared = b.create(loc, xDistance, xDistance); + Value xDistanceSquared = + arith::MulFOp::create(b, loc, xDistance, xDistance); Value xDistanceCubed = - b.create(loc, xDistanceSquared, xDistance); + arith::MulFOp::create(b, loc, xDistanceSquared, xDistance); - Value lessEqualOne = b.create(loc, a, cstTwoFloat); - lessEqualOne = b.create(loc, xDistanceCubed, lessEqualOne); - Value aPlusThree = b.create(loc, a, cstThreeFloat); - aPlusThree = b.create(loc, xDistanceSquared, aPlusThree); - lessEqualOne = b.create(loc, lessEqualOne, aPlusThree); - lessEqualOne = b.create(loc, lessEqualOne, cstOneFloat); + Value lessEqualOne = arith::AddFOp::create(b, loc, a, cstTwoFloat); + lessEqualOne = arith::MulFOp::create(b, loc, xDistanceCubed, lessEqualOne); + Value 
aPlusThree = arith::AddFOp::create(b, loc, a, cstThreeFloat); + aPlusThree = arith::MulFOp::create(b, loc, xDistanceSquared, aPlusThree); + lessEqualOne = arith::SubFOp::create(b, loc, lessEqualOne, aPlusThree); + lessEqualOne = arith::AddFOp::create(b, loc, lessEqualOne, cstOneFloat); return lessEqualOne; }; // a|x|^3 - 5a|x|^2 + 8a|x| - 4a for xDistance (1 < |x| < 2) auto WeightLessThanTwo = [&](Value xDistance) -> Value { - Value xDistanceSquared = b.create(loc, xDistance, xDistance); + Value xDistanceSquared = + arith::MulFOp::create(b, loc, xDistance, xDistance); Value xDistanceCubed = - b.create(loc, xDistanceSquared, xDistance); + arith::MulFOp::create(b, loc, xDistanceSquared, xDistance); // a|x|^3 - Value lessThanTwo = b.create(loc, xDistanceCubed, a); + Value lessThanTwo = arith::MulFOp::create(b, loc, xDistanceCubed, a); - Value fiveA = b.create(loc, xDistanceSquared, a); - fiveA = b.create(loc, fiveA, cstFiveFloat); + Value fiveA = arith::MulFOp::create(b, loc, xDistanceSquared, a); + fiveA = arith::MulFOp::create(b, loc, fiveA, cstFiveFloat); // a|x|^3 - 5a|x|^2 - lessThanTwo = b.create(loc, lessThanTwo, fiveA); + lessThanTwo = arith::SubFOp::create(b, loc, lessThanTwo, fiveA); - Value eightA = b.create(loc, a, xDistance); - eightA = b.create(loc, eightA, cstEightFloat); + Value eightA = arith::MulFOp::create(b, loc, a, xDistance); + eightA = arith::MulFOp::create(b, loc, eightA, cstEightFloat); // a|x|^3 - 5a|x|^2 + 8a|x| - lessThanTwo = b.create(loc, eightA, lessThanTwo); + lessThanTwo = arith::AddFOp::create(b, loc, eightA, lessThanTwo); - Value fourA = b.create(loc, a, cstFourFloat); + Value fourA = arith::MulFOp::create(b, loc, a, cstFourFloat); // a|x|^3 - 5a|x|^2 + 8a|x| - 4a - lessThanTwo = b.create(loc, lessThanTwo, fourA); + lessThanTwo = arith::SubFOp::create(b, loc, lessThanTwo, fourA); return lessThanTwo; }; @@ -3045,7 +3050,7 @@ static Value bicubicInterpolate(OpBuilder &b, SmallVector indices; for (unsigned i = 0; i < inputRank; i++) { - 
indices.push_back(b.create(loc, i)); + indices.push_back(linalg::IndexOp::create(b, loc, i)); } SmallVector proj; @@ -3055,34 +3060,34 @@ static Value bicubicInterpolate(OpBuilder &b, false); // get the nearest neighbors of proj - Value x1 = b.create(loc, proj[1]); - Value x_1 = b.create(loc, x1, cstOneFloat); - Value x_2 = b.create(loc, x_1, cstOneFloat); - Value x2 = b.create(loc, x1, cstOneFloat); + Value x1 = math::CeilOp::create(b, loc, proj[1]); + Value x_1 = arith::SubFOp::create(b, loc, x1, cstOneFloat); + Value x_2 = arith::SubFOp::create(b, loc, x_1, cstOneFloat); + Value x2 = arith::AddFOp::create(b, loc, x1, cstOneFloat); - Value y1 = b.create(loc, proj[0]); - Value y_1 = b.create(loc, y1, cstOneFloat); - Value y_2 = b.create(loc, y_1, cstOneFloat); - Value y2 = b.create(loc, y1, cstOneFloat); + Value y1 = math::CeilOp::create(b, loc, proj[0]); + Value y_1 = arith::SubFOp::create(b, loc, y1, cstOneFloat); + Value y_2 = arith::SubFOp::create(b, loc, y_1, cstOneFloat); + Value y2 = arith::AddFOp::create(b, loc, y1, cstOneFloat); // calculate the distance of nearest neighbors x and y to proj - Value y2Distance = b.create(loc, proj[0], y2); - y2Distance = b.create(loc, y2Distance); - Value y1Distance = b.create(loc, proj[0], y1); - y1Distance = b.create(loc, y1Distance); - Value y_1Distance = b.create(loc, proj[0], y_1); - y_1Distance = b.create(loc, y_1Distance); - Value y_2Distance = b.create(loc, proj[0], y_2); - y_2Distance = b.create(loc, y_2Distance); - - Value x2Distance = b.create(loc, proj[1], x2); - x2Distance = b.create(loc, x2Distance); - Value x1Distance = b.create(loc, proj[1], x1); - x1Distance = b.create(loc, x1Distance); - Value x_1Distance = b.create(loc, proj[1], x_1); - x_1Distance = b.create(loc, x_1Distance); - Value x_2Distance = b.create(loc, proj[1], x_2); - x_2Distance = b.create(loc, x_2Distance); + Value y2Distance = arith::SubFOp::create(b, loc, proj[0], y2); + y2Distance = math::AbsFOp::create(b, loc, y2Distance); + Value 
y1Distance = arith::SubFOp::create(b, loc, proj[0], y1); + y1Distance = math::AbsFOp::create(b, loc, y1Distance); + Value y_1Distance = arith::SubFOp::create(b, loc, proj[0], y_1); + y_1Distance = math::AbsFOp::create(b, loc, y_1Distance); + Value y_2Distance = arith::SubFOp::create(b, loc, proj[0], y_2); + y_2Distance = math::AbsFOp::create(b, loc, y_2Distance); + + Value x2Distance = arith::SubFOp::create(b, loc, proj[1], x2); + x2Distance = math::AbsFOp::create(b, loc, x2Distance); + Value x1Distance = arith::SubFOp::create(b, loc, proj[1], x1); + x1Distance = math::AbsFOp::create(b, loc, x1Distance); + Value x_1Distance = arith::SubFOp::create(b, loc, proj[1], x_1); + x_1Distance = math::AbsFOp::create(b, loc, x_1Distance); + Value x_2Distance = arith::SubFOp::create(b, loc, proj[1], x_2); + x_2Distance = math::AbsFOp::create(b, loc, x_2Distance); SmallVector y{y_2, y_1, y1, y2}; SmallVector x{x_2, x_1, x1, x2}; @@ -3096,17 +3101,17 @@ static Value bicubicInterpolate(OpBuilder &b, // clip the nearest neighbors points to inside the original image for (int k = 0; k < 4; k++) { - Value yClipped = b.create(loc, y[k], zero); - Value inputHSubOne = b.create(loc, inputFPH, cstOneFloat); - yClipped = b.create(loc, yClipped, inputHSubOne); - Value yInt = b.create(loc, b.getI64Type(), yClipped); - y[k] = b.create(loc, b.getIndexType(), yInt); - - Value xClipped = b.create(loc, x[k], zero); - Value inputWSubOne = b.create(loc, inputFPW, cstOneFloat); - xClipped = b.create(loc, xClipped, inputWSubOne); - Value xInt = b.create(loc, b.getI64Type(), xClipped); - x[k] = b.create(loc, b.getIndexType(), xInt); + Value yClipped = arith::MaximumFOp::create(b, loc, y[k], zero); + Value inputHSubOne = arith::SubFOp::create(b, loc, inputFPH, cstOneFloat); + yClipped = arith::MinimumFOp::create(b, loc, yClipped, inputHSubOne); + Value yInt = arith::FPToSIOp::create(b, loc, b.getI64Type(), yClipped); + y[k] = arith::IndexCastOp::create(b, loc, b.getIndexType(), yInt); + + Value 
xClipped = arith::MaximumFOp::create(b, loc, x[k], zero); + Value inputWSubOne = arith::SubFOp::create(b, loc, inputFPW, cstOneFloat); + xClipped = arith::MinimumFOp::create(b, loc, xClipped, inputWSubOne); + Value xInt = arith::FPToSIOp::create(b, loc, b.getI64Type(), xClipped); + x[k] = arith::IndexCastOp::create(b, loc, b.getIndexType(), xInt); } // 1. Compute x_original and y_original (proj) // 2. Compute nearest x and y neighbors @@ -3132,13 +3137,13 @@ static Value bicubicInterpolate(OpBuilder &b, indices[dimOffset + 1] = x[i]; - Value p = b.create(loc, input, indices); + Value p = tensor::ExtractOp::create(b, loc, input, indices); - Value wxp = b.create(loc, wx, p); - xInterpy = b.create(loc, xInterpy, wxp); + Value wxp = arith::MulFOp::create(b, loc, wx, p); + xInterpy = arith::AddFOp::create(b, loc, xInterpy, wxp); } - Value wyXInterpy = b.create(loc, wy, xInterpy); - fxy = b.create(loc, fxy, wyXInterpy); + Value wyXInterpy = arith::MulFOp::create(b, loc, wy, xInterpy); + fxy = arith::AddFOp::create(b, loc, fxy, wyXInterpy); } return fxy; @@ -3178,8 +3183,8 @@ class ConvertInterpolateOp SmallVector ScaleFactorFloatValues; for (unsigned i = 2; i < inputRank; i++) { Value inputSize = getDimOp(rewriter, loc, input, i); - inputSizes.push_back(rewriter.create( - loc, rewriter.getIntegerType(64), inputSize)); + inputSizes.push_back(arith::IndexCastOp::create( + rewriter, loc, rewriter.getIntegerType(64), inputSize)); } if (!isa(op.getScaleFactor().getType())) { @@ -3195,15 +3200,15 @@ class ConvertInterpolateOp ScaleFactorFloatValues = getTypeConvertedValues( rewriter, loc, getTypeConverter(), ScaleFactorTorchFloat); for (unsigned i = 0; i < inputRank - 2; i++) { - Value inputSizeFP = rewriter.create( - loc, rewriter.getF32Type(), inputSizes[i]); - ScaleFactorFloatValues[i] = rewriter.create( - loc, inputSizeFP.getType(), ScaleFactorFloatValues[i]); - Value outputSize = rewriter.create( - loc, inputSizeFP, ScaleFactorFloatValues[i]); - outputSize = 
rewriter.create(loc, outputSize); - outputSize = rewriter.create( - loc, rewriter.getI64Type(), outputSize); + Value inputSizeFP = arith::SIToFPOp::create( + rewriter, loc, rewriter.getF32Type(), inputSizes[i]); + ScaleFactorFloatValues[i] = arith::TruncFOp::create( + rewriter, loc, inputSizeFP.getType(), ScaleFactorFloatValues[i]); + Value outputSize = arith::MulFOp::create(rewriter, loc, inputSizeFP, + ScaleFactorFloatValues[i]); + outputSize = math::FloorOp::create(rewriter, loc, outputSize); + outputSize = arith::FPToSIOp::create(rewriter, loc, + rewriter.getI64Type(), outputSize); outputSizeIntValues.push_back(outputSize); } if (recompScale) @@ -3222,41 +3227,38 @@ class ConvertInterpolateOp dims.push_back(castIntToIndex(rewriter, loc, outputSizeIntValues[i - 2])); } - Value outTensor = rewriter.create( - loc, getAsOpFoldResult(dims), inputType.getElementType()); + Value outTensor = tensor::EmptyOp::create( + rewriter, loc, getAsOpFoldResult(dims), inputType.getElementType()); AffineMap idMap = rewriter.getMultiDimIdentityMap(inputRank); SmallVector iteratorTypes( inputRank, utils::IteratorType::parallel); Value finalRes = - rewriter - .create( - loc, outTensor.getType(), ValueRange{}, outTensor, - /*indexingMaps=*/idMap, - /*iteratorTypes=*/iteratorTypes, - [&](OpBuilder &b, Location loc, ValueRange args) { - Value retVal; - if (mode.substr(0, 7) == "nearest") { - std::string coordTfMode = - mode.substr(7, mode.find(",") - 7); - std::string nearestMode = - (mode.find(",") == std::string::npos) - ? 
"" - : mode.substr(mode.find(",") + 1); - retVal = nearestInterpolate( - b, loc, outputSizeIntValues, input, inputSizes, - ScaleFactorFloatValues, coordTfMode, nearestMode); - } else if (mode.substr(0, 8) == "bilinear") { - retVal = bilinearInterpolate( - b, op, loc, outputSizeIntValues, input, inputSizes, - ScaleFactorFloatValues, mode.substr(8)); - } else if (mode.substr(0, 5) == "cubic") { - - retVal = bicubicInterpolate( - b, op, loc, outputSizeIntValues, input, inputSizes, - ScaleFactorFloatValues, mode.substr(5)); - } - b.create(loc, retVal); - }) + linalg::GenericOp::create( + rewriter, loc, outTensor.getType(), ValueRange{}, outTensor, + /*indexingMaps=*/idMap, + /*iteratorTypes=*/iteratorTypes, + [&](OpBuilder &b, Location loc, ValueRange args) { + Value retVal; + if (mode.substr(0, 7) == "nearest") { + std::string coordTfMode = mode.substr(7, mode.find(",") - 7); + std::string nearestMode = (mode.find(",") == std::string::npos) + ? "" + : mode.substr(mode.find(",") + 1); + retVal = nearestInterpolate(b, loc, outputSizeIntValues, input, + inputSizes, ScaleFactorFloatValues, + coordTfMode, nearestMode); + } else if (mode.substr(0, 8) == "bilinear") { + retVal = bilinearInterpolate( + b, op, loc, outputSizeIntValues, input, inputSizes, + ScaleFactorFloatValues, mode.substr(8)); + } else if (mode.substr(0, 5) == "cubic") { + + retVal = bicubicInterpolate( + b, op, loc, outputSizeIntValues, input, inputSizes, + ScaleFactorFloatValues, mode.substr(5)); + } + linalg::YieldOp::create(b, loc, retVal); + }) .getResult(0); Type newResultType = getTypeConverter()->convertType(op.getResult().getType()); @@ -3282,8 +3284,8 @@ class ConvertAtenLinalgDetOp : public OpConversionPattern { unsigned inputRank = inputType.getRank(); auto elemTy = inputType.getElementType(); bool isBatched = (inputRank == 3); - Value cstZero = rewriter.create(loc, 0); - Value cstOne = rewriter.create(loc, 1); + Value cstZero = arith::ConstantIndexOp::create(rewriter, loc, 0); + Value cstOne = 
arith::ConstantIndexOp::create(rewriter, loc, 1); Value cstZeroF = getConstant(rewriter, loc, 0, elemTy); // get some shapes SmallVector inputShape(inputType.getShape()); @@ -3313,17 +3315,18 @@ class ConvertAtenLinalgDetOp : public OpConversionPattern { SmallVector inputSizes = getTensorSizes(rewriter, loc, input); Value chDim = isBatched ? inputSizes[0] : cstOne; Value matDim = inputSizes[inputRank - 1]; - Value matDimMinusOne = rewriter.create(loc, matDim, cstOne); + Value matDimMinusOne = arith::SubIOp::create(rewriter, loc, matDim, cstOne); ArrayRef sliceSizes(inputSizes.begin(), inputSizes.end() - 1); // initialize a tensor to store the diagonal elements found during row // reduction - Value initDiags = rewriter.create( - loc, getAsOpFoldResult(sliceSizes), elemTy); + Value initDiags = tensor::EmptyOp::create( + rewriter, loc, getAsOpFoldResult(sliceSizes), elemTy); // loop over each pivot row in A. Get the diagonal, then reduce the // subdiagonal Don't perform the loop on the last row since no further // reduction is needed. 
- auto rowReductionLoop = rewriter.create( - loc, /*start=*/cstZero, /*end=*/matDimMinusOne, /*step=*/cstOne, + auto rowReductionLoop = scf::ForOp::create( + rewriter, loc, /*start=*/cstZero, /*end=*/matDimMinusOne, + /*step=*/cstOne, /*yeild_to=*/ValueRange{input, initDiags}, /*body_lambda=*/ [&](OpBuilder &b, Location loc, Value row, ValueRange vals) { // extract row i from input Tensor of shape CxNxN or shape @@ -3336,17 +3339,17 @@ class ConvertAtenLinalgDetOp : public OpConversionPattern { auto sizes = getAsOpFoldResult(inputSizes); sizes[inputRank - 2] = cstOneFold; // offsets = [0, row, 0], sizes = [C, 1, N] -> pivot row - Value pivot = b.create( - loc, sliceTy, vals[0], offsets, sizes, strides); + Value pivot = tensor::ExtractSliceOp::create(b, loc, sliceTy, vals[0], + offsets, sizes, strides); // extract diagonal elements and insert them into vals[1] offsets.back() = row; sizes.back() = cstOneFold; // offsets = [0, row, row], sizes = [C, 1, 1] -> diag(row,row) - Value diag = b.create( - loc, diagTy, vals[0], offsets, sizes, strides); + Value diag = tensor::ExtractSliceOp::create(b, loc, diagTy, vals[0], + offsets, sizes, strides); - Value diagCollapse = b.create( - loc, diagCollapseTy, diag, diagReassociations); + Value diagCollapse = tensor::CollapseShapeOp::create( + b, loc, diagCollapseTy, diag, diagReassociations); SmallVector diagOffsets(inputRank - 1, cstZeroFold); diagOffsets.back() = row; @@ -3354,8 +3357,9 @@ class ConvertAtenLinalgDetOp : public OpConversionPattern { SmallVector diagSizes = getAsOpFoldResult(sliceSizes); diagSizes.back() = cstOneFold; // offsets = [0, row], sizes = [C, 1] insert to [C,N] - Value updatedDiags = b.create( - loc, diagCollapse, vals[1], diagOffsets, diagSizes, diagStrides); + Value updatedDiags = tensor::InsertSliceOp::create( + b, loc, diagCollapse, vals[1], diagOffsets, diagSizes, + diagStrides); // the subpivot matrix column size, as a Value, is matDim - row - // cstOne. 
This can't be statically converted to an int64_t, since row // is the loop index, so this is left as a dynamic dim. @@ -3365,21 +3369,21 @@ class ConvertAtenLinalgDetOp : public OpConversionPattern { subPivotShape.end() - 1); auto subPivotTy = RankedTensorType::get(subPivotShape, elemTy); auto subDiagTy = RankedTensorType::get(subDiagShape, elemTy); - Value rowPlusOne = b.create(loc, row, cstOne); + Value rowPlusOne = arith::AddIOp::create(b, loc, row, cstOne); offsets[inputRank - 2] = getAsOpFoldResult(rowPlusOne); sizes[inputRank - 2] = getAsOpFoldResult( - b.create(loc, matDim, rowPlusOne)); + arith::SubIOp::create(b, loc, matDim, rowPlusOne)); // offsets = [0, row + 1, row], sizes = [C, N - row - 1, 1] -> A_j,row // with j > row - Value subDiag = b.create( - loc, subDiagTy, vals[0], offsets, sizes, strides); + Value subDiag = tensor::ExtractSliceOp::create( + b, loc, subDiagTy, vals[0], offsets, sizes, strides); offsets.back() = cstZeroFold; sizes.back() = getAsOpFoldResult(matDim); // offsets = [0, row + 1, 0], sizes = [C, N - row - 1, N] -> elements // below pivot row - Value subPivot = b.create( - loc, subPivotTy, vals[0], offsets, sizes, strides); - Value initResult = b.create(loc, sizes, elemTy); + Value subPivot = tensor::ExtractSliceOp::create( + b, loc, subPivotTy, vals[0], offsets, sizes, strides); + Value initResult = tensor::EmptyOp::create(b, loc, sizes, elemTy); // write a generic op to perform subpivot = subpivot - // (subdiag/diag)*pivot // d0 = batches, d1 = row, d2 = column -> pivot(d0,d2), diag(d0), @@ -3416,40 +3420,40 @@ class ConvertAtenLinalgDetOp : public OpConversionPattern { SmallVector iteratorTypes( inputRank, utils::IteratorType::parallel); Value reducedSubPivot = - b.create( - loc, subPivotTy, ValueRange{pivot, diag, subPivot, subDiag}, - initResult, indexingMaps, iteratorTypes, - [&](OpBuilder &b, Location loc, ValueRange args) { - // for d0 in batches, d1 in subpivotrows, d2 in columns - // let i represent the pivot row index (scf 
loop index) - Value pivotd0d2 = args[0]; - Value diagd0 = args[1]; - Value subPivotd0d1d2 = args[2]; - Value subDiagd0d1 = args[3]; - // coeff = A_d1,i / A_i,i - Value coeff = - b.create(loc, subDiagd0d1, diagd0); - auto cmp = b.create( - loc, arith::CmpFPredicate::ONE, diagd0, cstZeroF); - b.create( - loc, cmp, - b.getStringAttr( - "unimplemented: determinants requiring " - "permutations and singular matrices")); - // coeff*A_i,d2 - Value scaledPivotValue = - b.create(loc, coeff, pivotd0d2); - // result = A_d1,d2 - (A_d1,i/A_i,i)*A_i,d2 - // so that when d2 = i, A_d1,i - (A_d1,i/A_i,i) * A_i,i = 0 - Value result = b.create(loc, subPivotd0d1d2, - scaledPivotValue); - b.create(loc, result); - }) + linalg::GenericOp::create( + b, loc, subPivotTy, + ValueRange{pivot, diag, subPivot, subDiag}, initResult, + indexingMaps, iteratorTypes, + [&](OpBuilder &b, Location loc, ValueRange args) { + // for d0 in batches, d1 in subpivotrows, d2 in columns + // let i represent the pivot row index (scf loop index) + Value pivotd0d2 = args[0]; + Value diagd0 = args[1]; + Value subPivotd0d1d2 = args[2]; + Value subDiagd0d1 = args[3]; + // coeff = A_d1,i / A_i,i + Value coeff = + arith::DivFOp::create(b, loc, subDiagd0d1, diagd0); + auto cmp = arith::CmpFOp::create( + b, loc, arith::CmpFPredicate::ONE, diagd0, cstZeroF); + cf::AssertOp::create( + b, loc, cmp, + b.getStringAttr("unimplemented: determinants requiring " + "permutations and singular matrices")); + // coeff*A_i,d2 + Value scaledPivotValue = + arith::MulFOp::create(b, loc, coeff, pivotd0d2); + // result = A_d1,d2 - (A_d1,i/A_i,i)*A_i,d2 + // so that when d2 = i, A_d1,i - (A_d1,i/A_i,i) * A_i,i = 0 + Value result = arith::SubFOp::create(b, loc, subPivotd0d1d2, + scaledPivotValue); + linalg::YieldOp::create(b, loc, result); + }) .getResult(0); - Value rowReductionResult = b.create( - loc, reducedSubPivot, vals[0], offsets, sizes, strides); - b.create(loc, - ValueRange{rowReductionResult, updatedDiags}); + Value 
rowReductionResult = tensor::InsertSliceOp::create( + b, loc, reducedSubPivot, vals[0], offsets, sizes, strides); + scf::YieldOp::create(b, loc, + ValueRange{rowReductionResult, updatedDiags}); }); Value allDiagsExceptLast = rowReductionLoop.getResult(1); SmallVector offsets(inputRank, @@ -3459,17 +3463,18 @@ class ConvertAtenLinalgDetOp : public OpConversionPattern { sizes[0] = getAsOpFoldResult(chDim); if (isBatched) offsets[0] = getAsOpFoldResult(cstZero); - Value lastDiag = rewriter.create( - loc, diagTy, rowReductionLoop.getResult(0), offsets, sizes, strides); + Value lastDiag = tensor::ExtractSliceOp::create( + rewriter, loc, diagTy, rowReductionLoop.getResult(0), offsets, sizes, + strides); offsets.pop_back(); strides.pop_back(); sizes.pop_back(); - lastDiag = rewriter.create( - loc, diagCollapseTy, lastDiag, diagReassociations); + lastDiag = tensor::CollapseShapeOp::create(rewriter, loc, diagCollapseTy, + lastDiag, diagReassociations); - Value allDiags = rewriter.create( - loc, lastDiag, allDiagsExceptLast, offsets, sizes, strides); + Value allDiags = tensor::InsertSliceOp::create( + rewriter, loc, lastDiag, allDiagsExceptLast, offsets, sizes, strides); // linalg generic to do reduce prod for allDiags along back dim. 
// the result of that generic will be the determinant SmallVector indexingMaps; @@ -3483,14 +3488,13 @@ class ConvertAtenLinalgDetOp : public OpConversionPattern { Value initDet = createInitTensor(rewriter, loc, ValueRange{chDim}, elemTy, getConstant(rewriter, loc, 1.0, elemTy)); Value determinant = - rewriter - .create( - loc, initDet.getType(), ValueRange{allDiags}, initDet, - indexingMaps, iteratorTypes, - [&](OpBuilder &b, Location loc, ValueRange args) { - Value prod = b.create(loc, args[0], args[1]); - b.create(loc, prod); - }) + linalg::GenericOp::create( + rewriter, loc, initDet.getType(), ValueRange{allDiags}, initDet, + indexingMaps, iteratorTypes, + [&](OpBuilder &b, Location loc, ValueRange args) { + Value prod = arith::MulFOp::create(b, loc, args[0], args[1]); + linalg::YieldOp::create(b, loc, prod); + }) .getResult(0); Type newResultType = getTypeConverter()->convertType(op.getResult().getType()); @@ -3500,8 +3504,8 @@ class ConvertAtenLinalgDetOp : public OpConversionPattern { return success(); } - determinant = rewriter.create( - loc, newResultType, determinant, + determinant = tensor::CollapseShapeOp::create( + rewriter, loc, newResultType, determinant, llvm::ArrayRef{}); rewriter.replaceOp(op, ValueRange{determinant}); return success(); @@ -3533,12 +3537,12 @@ class ConvertAtenPolarOp : public OpConversionPattern { SmallVector resultShape; for (int64_t i = 0; i < resultType.getRank(); i++) { - auto currentDimSize = rewriter.create(loc, absTensor, i); + auto currentDimSize = tensor::DimOp::create(rewriter, loc, absTensor, i); resultShape.push_back(currentDimSize); } - Value outTensor = rewriter.create( - loc, getAsOpFoldResult(resultShape), elementType); + Value outTensor = tensor::EmptyOp::create( + rewriter, loc, getAsOpFoldResult(resultShape), elementType); SmallVector outputExpr; for (unsigned i = 0; i < resultType.getRank(); i++) { @@ -3552,22 +3556,22 @@ class ConvertAtenPolarOp : public OpConversionPattern { SmallVector iteratorTypes( 
resultType.getRank(), utils::IteratorType::parallel); auto complexVar = - rewriter - .create( - loc, outTensor.getType(), ValueRange{absTensor, angleTensor}, - outTensor, indexingMaps, iteratorTypes, - [&](OpBuilder &b, Location loc, ValueRange args) { - // out = absâ‹…cos(angle) + absâ‹…sin(angle)â‹…j - Value abs = args[0]; - Value angle = args[1]; - Value realVal = b.create(loc, angle); - Value imagVal = b.create(loc, angle); - realVal = b.create(loc, abs, realVal); - imagVal = b.create(loc, abs, imagVal); - Value complexVal = b.create( - loc, elementType, realVal, imagVal); - b.create(loc, complexVal); - }) + linalg::GenericOp::create( + rewriter, loc, outTensor.getType(), + ValueRange{absTensor, angleTensor}, outTensor, indexingMaps, + iteratorTypes, + [&](OpBuilder &b, Location loc, ValueRange args) { + // out = absâ‹…cos(angle) + absâ‹…sin(angle)â‹…j + Value abs = args[0]; + Value angle = args[1]; + Value realVal = math::CosOp::create(b, loc, angle); + Value imagVal = math::SinOp::create(b, loc, angle); + realVal = arith::MulFOp::create(b, loc, abs, realVal); + imagVal = arith::MulFOp::create(b, loc, abs, imagVal); + Value complexVal = complex::CreateOp::create(b, loc, elementType, + realVal, imagVal); + linalg::YieldOp::create(b, loc, complexVal); + }) .getResult(0); rewriter.replaceOpWithNewOp(op, resultType, complexVar); return success(); @@ -3619,17 +3623,17 @@ class ConvertSymConstrainRangeOp // FIXME:: Skip the below checks if constraint ops are already inserted as // part of symbol expr evaluation - auto checkMin = rewriter.create( - loc, arith::CmpIPredicate::sle, min, adaptor.getSize()); - auto checkMax = rewriter.create( - loc, arith::CmpIPredicate::sle, adaptor.getSize(), max); - auto compareVal = rewriter.create(loc, checkMin, checkMax); + auto checkMin = arith::CmpIOp::create( + rewriter, loc, arith::CmpIPredicate::sle, min, adaptor.getSize()); + auto checkMax = arith::CmpIOp::create( + rewriter, loc, arith::CmpIPredicate::sle, adaptor.getSize(), 
max); + auto compareVal = arith::AndIOp::create(rewriter, loc, checkMin, checkMax); std::string assertMessage = "Size constraint failed. Expected range: [" + std::to_string(minValue) + ", " + std::to_string(maxValue) + "]"; - rewriter.create(loc, compareVal, - rewriter.getStringAttr(assertMessage)); + cf::AssertOp::create(rewriter, loc, compareVal, + rewriter.getStringAttr(assertMessage)); rewriter.eraseOp(op); return success(); @@ -3846,24 +3850,24 @@ class ConvertOnnxVariantRotaryEmbeddingOp SmallVector resultShape; for (int64_t i = 0; i < inputRank; i++) { - auto currentDimSize = rewriter.create(loc, input, i); + auto currentDimSize = tensor::DimOp::create(rewriter, loc, input, i); resultShape.push_back(currentDimSize); } Value outTensor = createZeroInitTensor(rewriter, loc, resultShape, elementType); - Value cstFloatOne = rewriter.create( - loc, rewriter.getFloatAttr(elementType, 1.0)); - Value cstFloatMinusOne = rewriter.create( - loc, rewriter.getFloatAttr(elementType, -1.0)); + Value cstFloatOne = arith::ConstantOp::create( + rewriter, loc, rewriter.getFloatAttr(elementType, 1.0)); + Value cstFloatMinusOne = arith::ConstantOp::create( + rewriter, loc, rewriter.getFloatAttr(elementType, -1.0)); Value cstIndexTwo = - rewriter.create(loc, rewriter.getIndexAttr(2)); + arith::ConstantOp::create(rewriter, loc, rewriter.getIndexAttr(2)); Value cstIndexOne = - rewriter.create(loc, rewriter.getIndexAttr(1)); - Value cstRotaryEmbDim = rewriter.create( - loc, rewriter.getIndexAttr(rotaryEmbeddingDim)); - Value cstHalfRotaryEmbDim = rewriter.create( - loc, rewriter.getIndexAttr(halfRotaryEmbDim)); + arith::ConstantOp::create(rewriter, loc, rewriter.getIndexAttr(1)); + Value cstRotaryEmbDim = arith::ConstantOp::create( + rewriter, loc, rewriter.getIndexAttr(rotaryEmbeddingDim)); + Value cstHalfRotaryEmbDim = arith::ConstantOp::create( + rewriter, loc, rewriter.getIndexAttr(halfRotaryEmbDim)); AffineMap identityMap = AffineMap::getMultiDimIdentityMap(inputRank, context); 
@@ -3875,85 +3879,84 @@ class ConvertOnnxVariantRotaryEmbeddingOp inputRank, utils::IteratorType::parallel); auto rotaryEmbedding = - rewriter - .create( - loc, outTensor.getType(), ValueRange{input, positionIds}, - outTensor, indexingMaps, iteratorTypes, - [&](OpBuilder &builder, Location loc, ValueRange args) { - // This linalg.generic will be iterating over the 4 dimensions - // of the input "b, n, s, h", respectively. - // - // if (interleaved): - // cache_idx = (h / 2) % half_rotary_emb_dim - // sign = h & 1 - // j = sign ? h - 1: h + 1 - // else: - // cache_idx = h % half_rotary_emb_dim - // sign = (h >= rotary_emb_dim) - // j = (h + half_rotary_emb_dim) % rotary_emb_dim - // - // orig_input = input[b][n][s][h] - // rotated_input = input[b][n][s][j] - // position_id = position_ids[b][s] - // cos_emb = cos_cache[position_id][cache_idx] - // sin_emb = sin_cache[position_id][cache_idx] - // out[b][n][s][h] = orig_input * cos_emb - // + - // (rotated_input * sin_emb) * sign - - Value b = builder.create(loc, 0); - Value n = builder.create(loc, 1); - Value s = builder.create(loc, 2); - Value h = builder.create(loc, 3); - - Value cacheIdx, sign, rotatedInputLastIdx; - if (interleaved) { - cacheIdx = - builder.create(loc, h, cstIndexTwo); - cacheIdx = builder.create( - loc, cacheIdx, cstHalfRotaryEmbDim); - sign = builder.create(loc, h, cstIndexOne); - // Converting sign value from index type to bool type. 
- sign = builder.create( - loc, rewriter.getI1Type(), sign); - rotatedInputLastIdx = builder.create( - loc, sign, - builder.create(loc, h, cstIndexOne), - builder.create(loc, h, cstIndexOne)); - } else { - cacheIdx = builder.create( - loc, h, cstHalfRotaryEmbDim); - sign = builder.create( - loc, arith::CmpIPredicate::sge, h, cstHalfRotaryEmbDim); - rotatedInputLastIdx = builder.create( - loc, h, cstHalfRotaryEmbDim); - rotatedInputLastIdx = builder.create( - loc, rotatedInputLastIdx, cstRotaryEmbDim); - } - - Value positionId = castIntToIndex(builder, loc, args[1]); - Value cosEmb = builder.create( - loc, cosCache, ValueRange{positionId, cacheIdx}); - Value sinEmb = builder.create( - loc, sinCache, ValueRange{positionId, cacheIdx}); - - Value origInput = args[0]; - Value rotatedInput = builder.create( - loc, input, ValueRange{b, n, s, rotatedInputLastIdx}); - - Value signMultiplier = builder.create( - loc, sign, cstFloatOne, cstFloatMinusOne); - - Value outputI = - builder.create(loc, origInput, cosEmb); - Value outputJ = - builder.create(loc, rotatedInput, sinEmb); - outputJ = builder.create(loc, outputJ, - signMultiplier); - Value out = - builder.create(loc, outputI, outputJ); - builder.create(loc, out); - }) + linalg::GenericOp::create( + rewriter, loc, outTensor.getType(), ValueRange{input, positionIds}, + outTensor, indexingMaps, iteratorTypes, + [&](OpBuilder &builder, Location loc, ValueRange args) { + // This linalg.generic will be iterating over the 4 dimensions + // of the input "b, n, s, h", respectively. + // + // if (interleaved): + // cache_idx = (h / 2) % half_rotary_emb_dim + // sign = h & 1 + // j = sign ? 
h - 1: h + 1 + // else: + // cache_idx = h % half_rotary_emb_dim + // sign = (h >= rotary_emb_dim) + // j = (h + half_rotary_emb_dim) % rotary_emb_dim + // + // orig_input = input[b][n][s][h] + // rotated_input = input[b][n][s][j] + // position_id = position_ids[b][s] + // cos_emb = cos_cache[position_id][cache_idx] + // sin_emb = sin_cache[position_id][cache_idx] + // out[b][n][s][h] = orig_input * cos_emb + // + + // (rotated_input * sin_emb) * sign + + Value b = linalg::IndexOp::create(builder, loc, 0); + Value n = linalg::IndexOp::create(builder, loc, 1); + Value s = linalg::IndexOp::create(builder, loc, 2); + Value h = linalg::IndexOp::create(builder, loc, 3); + + Value cacheIdx, sign, rotatedInputLastIdx; + if (interleaved) { + cacheIdx = arith::DivSIOp::create(builder, loc, h, cstIndexTwo); + cacheIdx = arith::RemSIOp::create(builder, loc, cacheIdx, + cstHalfRotaryEmbDim); + sign = arith::AndIOp::create(builder, loc, h, cstIndexOne); + // Converting sign value from index type to bool type. 
+ sign = arith::TruncIOp::create(builder, loc, + rewriter.getI1Type(), sign); + rotatedInputLastIdx = arith::SelectOp::create( + builder, loc, sign, + arith::SubIOp::create(builder, loc, h, cstIndexOne), + arith::AddIOp::create(builder, loc, h, cstIndexOne)); + } else { + cacheIdx = arith::RemSIOp::create(builder, loc, h, + cstHalfRotaryEmbDim); + sign = arith::CmpIOp::create(builder, loc, + arith::CmpIPredicate::sge, h, + cstHalfRotaryEmbDim); + rotatedInputLastIdx = + arith::AddIOp::create(builder, loc, h, cstHalfRotaryEmbDim); + rotatedInputLastIdx = arith::RemSIOp::create( + builder, loc, rotatedInputLastIdx, cstRotaryEmbDim); + } + + Value positionId = castIntToIndex(builder, loc, args[1]); + Value cosEmb = tensor::ExtractOp::create( + builder, loc, cosCache, ValueRange{positionId, cacheIdx}); + Value sinEmb = tensor::ExtractOp::create( + builder, loc, sinCache, ValueRange{positionId, cacheIdx}); + + Value origInput = args[0]; + Value rotatedInput = tensor::ExtractOp::create( + builder, loc, input, + ValueRange{b, n, s, rotatedInputLastIdx}); + + Value signMultiplier = arith::SelectOp::create( + builder, loc, sign, cstFloatOne, cstFloatMinusOne); + + Value outputI = + arith::MulFOp::create(builder, loc, origInput, cosEmb); + Value outputJ = + arith::MulFOp::create(builder, loc, rotatedInput, sinEmb); + outputJ = + arith::MulFOp::create(builder, loc, outputJ, signMultiplier); + Value out = arith::AddFOp::create(builder, loc, outputI, outputJ); + linalg::YieldOp::create(builder, loc, out); + }) .getResult(0); rewriter.replaceOpWithNewOp(op, resultType, diff --git a/lib/Conversion/TorchToLinalg/Utils.cpp b/lib/Conversion/TorchToLinalg/Utils.cpp index e98ad5dca084..c2b584efdecc 100644 --- a/lib/Conversion/TorchToLinalg/Utils.cpp +++ b/lib/Conversion/TorchToLinalg/Utils.cpp @@ -40,9 +40,9 @@ Value torch_to_linalg::getPaddedTensor( getIndexIntsAsOpFoldResult(b, lowPaddingInts); SmallVector highPaddings = getIndexIntsAsOpFoldResult(b, highPaddingInts); - Value 
paddedInput = - b.create(loc, rankedTensorType, input, /*low=*/lowPaddings, - /*high=*/highPaddings, pad); + Value paddedInput = tensor::PadOp::create(b, loc, rankedTensorType, input, + /*low=*/lowPaddings, + /*high=*/highPaddings, pad); return paddedInput; } @@ -55,8 +55,8 @@ Value torch_to_linalg::getZeroPaddedTensor( assert(isa(input.getType()) && "input must be RankedTensorType"); Location loc = op->getLoc(); - Value c0 = b.create( - loc, + Value c0 = arith::ConstantOp::create( + b, loc, b.getZeroAttr(cast(input.getType()).getElementType())); return getPaddedTensor(op, b, input, paddingInts, paddingInts, c0); } @@ -72,7 +72,7 @@ Value torch_to_linalg::getDynamicZeroPaddedTensor( Location loc = op->getLoc(); SmallVector inputDims = getTensorSizes(b, loc, input); - Value c0 = b.create(loc, b.getI64IntegerAttr(0)); + Value c0 = arith::ConstantOp::create(b, loc, b.getI64IntegerAttr(0)); SmallVector paddingIncludingUnchanged(unpaddedDims, c0); paddingIncludingUnchanged.append(padding); assert(static_cast(unpaddedDims + padding.size()) == @@ -85,8 +85,8 @@ Value torch_to_linalg::getDynamicZeroPaddedTensor( SmallVector paddingValues = getAsOpFoldResult(paddingIncludingUnchanged); - return b.create(loc, Type{}, input, /*low=*/paddingValues, - /*high=*/paddingValues, pad); + return tensor::PadOp::create(b, loc, Type{}, input, /*low=*/paddingValues, + /*high=*/paddingValues, pad); } Value torch_to_linalg::getOutputDimForConvOps(OpBuilder &b, Location loc, @@ -94,8 +94,8 @@ Value torch_to_linalg::getOutputDimForConvOps(OpBuilder &b, Location loc, Value dilationInt, Value kernelSizeInt, Value strideInt, bool ceilMode) { - Value c1 = b.create(loc, b.getI64IntegerAttr(1)); - Value c2 = b.create(loc, b.getI64IntegerAttr(2)); + Value c1 = arith::ConstantOp::create(b, loc, b.getI64IntegerAttr(1)); + Value c2 = arith::ConstantOp::create(b, loc, b.getI64IntegerAttr(2)); Value doublePadding = b.createOrFold(loc, paddingInt, c2); // in + 2 * padding @@ -141,9 +141,9 @@ Value 
torch_to_linalg::getOutputDimForPoolOps(OpBuilder &b, Location loc, Value dilationInt, Value kernelSizeInt, Value strideInt, bool ceilMode) { - Value c1 = b.create(loc, b.getI64IntegerAttr(1)); + Value c1 = arith::ConstantOp::create(b, loc, b.getI64IntegerAttr(1)); Value totalPaddingIntCst = - b.create(loc, b.getI64IntegerAttr(totalPadding)); + arith::ConstantOp::create(b, loc, b.getI64IntegerAttr(totalPadding)); // in + totalPadding Value inAddTotalPadding = b.createOrFold( @@ -170,7 +170,7 @@ Value torch_to_linalg::getOutputDimForPoolOps(OpBuilder &b, Location loc, Value outMinusOneTimesStride = b.createOrFold(loc, division, strideInt); Value leftPaddingIntCst = - b.create(loc, b.getI64IntegerAttr(leftPadding)); + arith::ConstantOp::create(b, loc, b.getI64IntegerAttr(leftPadding)); Value inAddLeftPadding = b.createOrFold( loc, castIndexToInt64(b, loc, in), leftPaddingIntCst); @@ -185,25 +185,25 @@ Value torch_to_linalg::getOutputDimForPoolOps(OpBuilder &b, Location loc, Value torch_to_linalg::getOutputDimForConvTransposeOps( OpBuilder &b, Location loc, Value in, Value paddingInt, Value dilationInt, Value kernelSizeInt, Value strideInt, Value outputPaddingInt) { - Value c1 = b.create(loc, b.getI64IntegerAttr(1)); - Value c2 = b.create(loc, b.getI64IntegerAttr(2)); + Value c1 = arith::ConstantOp::create(b, loc, b.getI64IntegerAttr(1)); + Value c2 = arith::ConstantOp::create(b, loc, b.getI64IntegerAttr(2)); // (in - 1) * stride Value inStrided = - b.create(loc, castIndexToInt64(b, loc, in), c1); - inStrided = b.create(loc, inStrided, strideInt); + arith::SubIOp::create(b, loc, castIndexToInt64(b, loc, in), c1); + inStrided = arith::MulIOp::create(b, loc, inStrided, strideInt); // 2 * padding - Value doublePadding = b.create(loc, paddingInt, c2); + Value doublePadding = arith::MulIOp::create(b, loc, paddingInt, c2); // (kernelSize - 1) * dilation - Value kernelDilated = b.create(loc, kernelSizeInt, c1); - kernelDilated = b.create(loc, kernelDilated, dilationInt); + 
Value kernelDilated = arith::SubIOp::create(b, loc, kernelSizeInt, c1); + kernelDilated = arith::MulIOp::create(b, loc, kernelDilated, dilationInt); - Value out = b.create(loc, inStrided, doublePadding); - out = b.create(loc, out, kernelDilated); - out = b.create(loc, out, outputPaddingInt); - out = b.create(loc, out, c1); + Value out = arith::SubIOp::create(b, loc, inStrided, doublePadding); + out = arith::AddIOp::create(b, loc, out, kernelDilated); + out = arith::AddIOp::create(b, loc, out, outputPaddingInt); + out = arith::AddIOp::create(b, loc, out, c1); return castIntToIndex(b, loc, out); } @@ -218,10 +218,11 @@ Value torch_to_linalg::createReductionLinalgGeneric( // If `opInfo.keepDim` is true, the rank of the output tensor // is kept the same as the rank of the input tensor, and the // reduced dimensions are set to have size 1. - auto c1 = b.create(loc, /*value=*/1); + auto c1 = arith::ConstantIndexOp::create(b, loc, /*value=*/1); SmallVector resultShape; for (int64_t i = 0; i < inputType.getRank(); i++) { - auto currentDimSize = b.create(loc, opInfo.tensorOperand, i); + auto currentDimSize = + tensor::DimOp::create(b, loc, opInfo.tensorOperand, i); if (!opInfo.dimSet.contains(i)) resultShape.push_back(currentDimSize); else if (opInfo.keepDim) @@ -255,11 +256,10 @@ Value torch_to_linalg::createReductionLinalgGeneric( Value accumulator = createInitTensor(b, loc, resultShape, initElem.getType(), initElem); - return b - .create( - loc, /*resultTensorTypes=*/accumulator.getType(), - /*inputs=*/opInfo.tensorOperand, - /*outputs=*/accumulator, indexingMaps, iteratorTypes, bodyBuild) + return linalg::GenericOp::create( + b, loc, /*resultTensorTypes=*/accumulator.getType(), + /*inputs=*/opInfo.tensorOperand, + /*outputs=*/accumulator, indexingMaps, iteratorTypes, bodyBuild) .getResult(0); } @@ -300,7 +300,7 @@ Value torch_to_linalg::createElementwiseLinalgGeneric( // Initialize the resultShape to all 1's, as a fallback in case // all sizes along that result 
dimension are statically 1. - auto c1 = b.create(loc, /*value=*/1); + auto c1 = arith::ConstantIndexOp::create(b, loc, /*value=*/1); SmallVector resultShape(resultRank, c1); // Record whether or not all corresponding input dims are statically 1. @@ -370,10 +370,10 @@ Value torch_to_linalg::createElementwiseLinalgGeneric( // dimensions sizes that are expected to match. if (!elideDynamicBroadcastCheck) { auto equalToRunning = - b.create(loc, arith::CmpIPredicate::eq, - resultShape[resultDim], currentDimSize); - b.create(loc, equalToRunning, - "mismatched size for broadcast"); + arith::CmpIOp::create(b, loc, arith::CmpIPredicate::eq, + resultShape[resultDim], currentDimSize); + cf::AssertOp::create(b, loc, equalToRunning, + "mismatched size for broadcast"); } } indexingMaps.push_back(AffineMap::get( @@ -385,14 +385,13 @@ Value torch_to_linalg::createElementwiseLinalgGeneric( // Add the indexing map for the outs init tensor. indexingMaps.push_back(b.getMultiDimIdentityMap(resultRank)); - Value initTensor = b.create( - loc, getAsOpFoldResult(resultShape), resultElementType); - return b - .create(loc, - /*resultTensorTypes=*/initTensor.getType(), - /*inputs=*/tensorOperands, - /*outputs=*/initTensor, indexingMaps, - iteratorTypes, bodyBuild) + Value initTensor = tensor::EmptyOp::create( + b, loc, getAsOpFoldResult(resultShape), resultElementType); + return linalg::GenericOp::create(b, loc, + /*resultTensorTypes=*/initTensor.getType(), + /*inputs=*/tensorOperands, + /*outputs=*/initTensor, indexingMaps, + iteratorTypes, bodyBuild) .getResult(0); } @@ -424,11 +423,11 @@ LogicalResult torch_to_linalg::broadcastToGivenShape( // Create affine map and shapes for tensor initialization. 
SmallVector outExpr; Value zero = - rewriter.create(loc, rewriter.getI64IntegerAttr(0)); + arith::ConstantOp::create(rewriter, loc, rewriter.getI64IntegerAttr(0)); Value zeroIndex = - rewriter.create(loc, rewriter.getIndexAttr(0)); + arith::ConstantOp::create(rewriter, loc, rewriter.getIndexAttr(0)); Value oneIndex = - rewriter.create(loc, rewriter.getIndexAttr(1)); + arith::ConstantOp::create(rewriter, loc, rewriter.getIndexAttr(1)); size_t diff = outputRank - inputRank; bool hasDynamicNumpyBroadcast = false; for (size_t i = 0, e = outputRank; i < e; i++) { @@ -464,10 +463,10 @@ LogicalResult torch_to_linalg::broadcastToGivenShape( if (i < diff) { if (!elideDynamicBroadcastCheck) { - Value isValid = rewriter.create( - loc, arith::CmpIPredicate::sge, shapeValue, zero); - rewriter.create( - loc, isValid, + Value isValid = arith::CmpIOp::create( + rewriter, loc, arith::CmpIPredicate::sge, shapeValue, zero); + cf::AssertOp::create( + rewriter, loc, isValid, rewriter.getStringAttr( "negative values not allowed in new dimensions")); } @@ -476,10 +475,11 @@ LogicalResult torch_to_linalg::broadcastToGivenShape( } if (inputShape[j] == 1) { // Broadcast singleton dimension - Value isNegative = rewriter.create( - loc, arith::CmpIPredicate::slt, shapeValue, zero); - Value select = rewriter.create( - loc, isNegative, oneIndex, castIntToIndex(rewriter, loc, shapeValue)); + Value isNegative = arith::CmpIOp::create( + rewriter, loc, arith::CmpIPredicate::slt, shapeValue, zero); + Value select = + arith::SelectOp::create(rewriter, loc, isNegative, oneIndex, + castIntToIndex(rewriter, loc, shapeValue)); outShape.push_back(select); broadcastedStatus.push_back(true); continue; @@ -494,10 +494,10 @@ LogicalResult torch_to_linalg::broadcastToGivenShape( hasDynamicNumpyBroadcast = true; } if (!elideDynamicBroadcastCheck) { - Value isValid = rewriter.create( - loc, arith::CmpIPredicate::sge, shapeValue, zero); - rewriter.create( - loc, isValid, + Value isValid = arith::CmpIOp::create( + 
rewriter, loc, arith::CmpIPredicate::sge, shapeValue, zero); + cf::AssertOp::create( + rewriter, loc, isValid, rewriter.getStringAttr( "unimplemented: dynamic negative broadcast sizes")); } @@ -511,7 +511,7 @@ LogicalResult torch_to_linalg::broadcastToGivenShape( } Value outTensor = - rewriter.create(loc, outShape, elementType); + tensor::EmptyOp::create(rewriter, loc, outShape, elementType); // If we know there are no ? -> ? broadcasted dims, or we are assuming // strict symbols, we can safely use standard linalg style broadcasting @@ -521,7 +521,8 @@ LogicalResult torch_to_linalg::broadcastToGivenShape( // the op away entirely. if (!llvm::any_of(broadcastedStatus, [](bool b) { return b; }) && inputRank == outputRank) { - result = rewriter.create(loc, outTensor.getType(), input); + result = + tensor::CastOp::create(rewriter, loc, outTensor.getType(), input); return success(); } @@ -554,8 +555,8 @@ LogicalResult torch_to_linalg::broadcastToGivenShape( } if (collapse) { - input = rewriter.create(op->getLoc(), input, - collapseExprs); + input = tensor::CollapseShapeOp::create(rewriter, op->getLoc(), input, + collapseExprs); } SmallVector indexingMaps = { @@ -563,13 +564,12 @@ LogicalResult torch_to_linalg::broadcastToGivenShape( rewriter.getMultiDimIdentityMap(outputRank)}; SmallVector iteratorTypes( outputRank, utils::IteratorType::parallel); - result = rewriter - .create( - loc, outTensor.getType(), input, outTensor, indexingMaps, - iteratorTypes, - [&](OpBuilder &b, Location loc, ValueRange args) { - b.create(loc, args[0]); - }) + result = linalg::GenericOp::create( + rewriter, loc, outTensor.getType(), input, outTensor, + indexingMaps, iteratorTypes, + [&](OpBuilder &b, Location loc, ValueRange args) { + linalg::YieldOp::create(b, loc, args[0]); + }) .getResult(0); return success(); } @@ -580,42 +580,41 @@ LogicalResult torch_to_linalg::broadcastToGivenShape( rewriter.getMultiDimIdentityMap(outputRank)}; SmallVector iteratorTypes(outputRank, 
utils::IteratorType::parallel); - result = rewriter - .create( - loc, outTensor.getType(), ValueRange(), outTensor, - indexingMaps, iteratorTypes, - [&](OpBuilder &b, Location loc, ValueRange args) { - // `loopIndices` contains IV of the linalg loops which - // would be used to extract values from the input tensor - // later on. - SmallVector loopIndices; - for (size_t i = 0, e = outputRank; i < e; ++i) { - if (i < diff) - continue; - loopIndices.push_back(b.create(loc, i)); - } - // `inputIndicesToExtract` contains i-th linalg loop IV if - // the i-th input dimension is not 1, else it contains a - // zero index. - SmallVector inputIndicesToExtract; - for (size_t i = 0, n = inputRank; i < n; i++) { - if (inputShape[i] == 1) { - inputIndicesToExtract.push_back(zeroIndex); - } else { - Value inputDim = getDimOp(b, loc, input, i); - Value isEqual = b.create( - loc, arith::CmpIPredicate::eq, inputDim, oneIndex); - Value select = rewriter.create( - loc, isEqual, zeroIndex, loopIndices[i]); - inputIndicesToExtract.push_back(select); - } - } - // Extract and yield the value from input tensor at - // `inputIndicesToExtract` indices. - Value result = b.create( - loc, input, inputIndicesToExtract); - b.create(loc, result); - }) + result = linalg::GenericOp::create( + rewriter, loc, outTensor.getType(), ValueRange(), outTensor, + indexingMaps, iteratorTypes, + [&](OpBuilder &b, Location loc, ValueRange args) { + // `loopIndices` contains IV of the linalg loops which + // would be used to extract values from the input tensor + // later on. + SmallVector loopIndices; + for (size_t i = 0, e = outputRank; i < e; ++i) { + if (i < diff) + continue; + loopIndices.push_back(linalg::IndexOp::create(b, loc, i)); + } + // `inputIndicesToExtract` contains i-th linalg loop IV if + // the i-th input dimension is not 1, else it contains a + // zero index. 
+ SmallVector inputIndicesToExtract; + for (size_t i = 0, n = inputRank; i < n; i++) { + if (inputShape[i] == 1) { + inputIndicesToExtract.push_back(zeroIndex); + } else { + Value inputDim = getDimOp(b, loc, input, i); + Value isEqual = arith::CmpIOp::create( + b, loc, arith::CmpIPredicate::eq, inputDim, oneIndex); + Value select = arith::SelectOp::create( + rewriter, loc, isEqual, zeroIndex, loopIndices[i]); + inputIndicesToExtract.push_back(select); + } + } + // Extract and yield the value from input tensor at + // `inputIndicesToExtract` indices. + Value result = tensor::ExtractOp::create( + b, loc, input, inputIndicesToExtract); + linalg::YieldOp::create(b, loc, result); + }) .getResult(0); return success(); @@ -626,8 +625,8 @@ Value torch_to_linalg::removeSizeInformation(OpBuilder &b, Location loc, auto tensorType = cast(tensor.getType()); auto rank = tensorType.getRank(); SmallVector unknownSizes(rank, kUnknownSize); - return b.create( - loc, tensorType.clone(makeShapeLLVMCompatible(unknownSizes)), tensor); + return tensor::CastOp::create( + b, loc, tensorType.clone(makeShapeLLVMCompatible(unknownSizes)), tensor); } Value torch_to_linalg::convertTensorToElementType(OpBuilder &b, Location loc, @@ -637,7 +636,7 @@ Value torch_to_linalg::convertTensorToElementType(OpBuilder &b, Location loc, ValueRange payloadArgs) { Value elem = convertScalarToDtype(builder, loc, payloadArgs[0], elementType); - builder.create(loc, elem); + linalg::YieldOp::create(builder, loc, elem); }; return torch_to_linalg::createElementwiseLinalgGeneric( b, loc, {tensor}, elementType, dtypePromoteBody); @@ -708,11 +707,11 @@ LogicalResult torch_to_linalg::permuteTensor(Operation *op, for (uint32_t i = 0; i < inputRank; i++) outputDims.push_back(getDimOp(rewriter, loc, input, dimensions[i])); - Value outVector = rewriter.create( - loc, getAsOpFoldResult(outputDims), elementType); + Value outVector = tensor::EmptyOp::create( + rewriter, loc, getAsOpFoldResult(outputDims), elementType); result 
= - rewriter.create(loc, input, outVector, dimensions) + linalg::TransposeOp::create(rewriter, loc, input, outVector, dimensions) ->getResult(0); return success(); } @@ -720,7 +719,7 @@ LogicalResult torch_to_linalg::permuteTensor(Operation *op, // Flips an input tensor based on the values of axis list. Value torch_to_linalg::flipTensor(PatternRewriter &rewriter, Location loc, Value input, SmallVector axis) { - Value c1 = rewriter.create(loc, rewriter.getIndexAttr(1)); + Value c1 = arith::ConstantOp::create(rewriter, loc, rewriter.getIndexAttr(1)); Type elementType = cast(input.getType()).getElementType(); auto selfRank = cast(input.getType()).getRank(); @@ -728,7 +727,7 @@ Value torch_to_linalg::flipTensor(PatternRewriter &rewriter, Location loc, // dims won't be used. SmallVector dims = getTensorSizes(rewriter, loc, input); for (auto flipDim : axis) - dims[flipDim] = rewriter.create(loc, dims[flipDim], c1); + dims[flipDim] = arith::SubIOp::create(rewriter, loc, dims[flipDim], c1); Value initTensor = createZeroInitTensor( rewriter, loc, getTensorSizes(rewriter, loc, input), elementType); @@ -738,22 +737,21 @@ Value torch_to_linalg::flipTensor(PatternRewriter &rewriter, Location loc, SmallVector indexingMaps( 2, AffineMap::getMultiDimIdentityMap(selfRank, rewriter.getContext())); Value flipped = - rewriter - .create( - loc, input.getType(), input, initTensor, indexingMaps, - iteratorTypes, - [&](OpBuilder &b, Location loc, ValueRange args) { - SmallVector indices; - for (auto i = 0; i < selfRank; i++) - indices.push_back(b.create(loc, i)); - for (auto flipDim : axis) { - indices[flipDim] = b.create(loc, dims[flipDim], - indices[flipDim]); - } - Value res = b.create(loc, input, indices) - .getResult(); - b.create(loc, res); - }) + linalg::GenericOp::create( + rewriter, loc, input.getType(), input, initTensor, indexingMaps, + iteratorTypes, + [&](OpBuilder &b, Location loc, ValueRange args) { + SmallVector indices; + for (auto i = 0; i < selfRank; i++) + 
indices.push_back(linalg::IndexOp::create(b, loc, i)); + for (auto flipDim : axis) { + indices[flipDim] = arith::SubIOp::create(b, loc, dims[flipDim], + indices[flipDim]); + } + Value res = + tensor::ExtractOp::create(b, loc, input, indices).getResult(); + linalg::YieldOp::create(b, loc, res); + }) .getResult(0); return flipped; } diff --git a/lib/Conversion/TorchToSCF/TorchToSCF.cpp b/lib/Conversion/TorchToSCF/TorchToSCF.cpp index 6f970de06592..8978a75c01a4 100644 --- a/lib/Conversion/TorchToSCF/TorchToSCF.cpp +++ b/lib/Conversion/TorchToSCF/TorchToSCF.cpp @@ -47,9 +47,9 @@ class ConvertTorchPrimIfOp : public OpConversionPattern { newResultTypes))) return rewriter.notifyMatchFailure(op, "could not convert PrimIfOp outputs"); - auto scfIf = rewriter.create(op->getLoc(), newResultTypes, - adaptor.getCondition(), - /*withElseRegion=*/true); + auto scfIf = scf::IfOp::create(rewriter, op->getLoc(), newResultTypes, + adaptor.getCondition(), + /*withElseRegion=*/true); auto inlineIfCase = [&](Region &srcRegion, Region &dstRegion) { rewriter.inlineRegionBefore(srcRegion, dstRegion, dstRegion.begin()); rewriter.eraseBlock(&dstRegion.back()); @@ -89,8 +89,8 @@ class ConvertTorchPrimLoopWhileLikeOp : public OpConversionPattern { ValueRange iterArgsInit = adaptor.getIterArgsInit(); SmallVector scfWhileOpOperands{condition}; scfWhileOpOperands.append(iterArgsInit.begin(), iterArgsInit.end()); - auto scfWhileOp = rewriter.create( - op->getLoc(), newResultTypes, scfWhileOpOperands); + auto scfWhileOp = scf::WhileOp::create(rewriter, op->getLoc(), + newResultTypes, scfWhileOpOperands); // Populate the before region of the scf.while operation. The `before` // region will have only one block and the arguments of the block must match @@ -108,8 +108,8 @@ class ConvertTorchPrimLoopWhileLikeOp : public OpConversionPattern { rewriter.setInsertionPointToEnd(beforeBlock); // Fetch the condition passed as the iter argument. Pass rest of the // arguments to the after block. 
- auto scfConditionOp = rewriter.create( - op.getLoc(), beforeBlock->getArgument(0), + auto scfConditionOp = scf::ConditionOp::create( + rewriter, op.getLoc(), beforeBlock->getArgument(0), beforeBlock->getArguments().drop_front()); // Populate the after region. @@ -185,8 +185,8 @@ class ConvertTorchPrimLoopWhileLikeOp : public OpConversionPattern { op, "unsupported type of the operand"); loopConditionIterArgs.push_back(arg); } - rewriter.create(scfWhileOp.getLoc(), - loopConditionIterArgs); + scf::YieldOp::create(rewriter, scfWhileOp.getLoc(), + loopConditionIterArgs); } else { operation.moveBefore(afterBlock, afterBlock->end()); @@ -221,13 +221,13 @@ class ConvertTorchPrimLoopForLikeOp : public OpConversionPattern { // Calculate the lower bound, upper bound and step indices. Currently only // lower-bound = 0 and step = 1 is supported. Location loc = op.getLoc(); - Value lowerBoundIndex = rewriter.create(loc, 0); - Value stepIndex = rewriter.create(loc, 1); - Value upperBoundIndex = rewriter.create( - loc, rewriter.getIndexType(), adaptor.getMaxTripCount()); + Value lowerBoundIndex = arith::ConstantIndexOp::create(rewriter, loc, 0); + Value stepIndex = arith::ConstantIndexOp::create(rewriter, loc, 1); + Value upperBoundIndex = arith::IndexCastOp::create( + rewriter, loc, rewriter.getIndexType(), adaptor.getMaxTripCount()); auto scfForOp = - rewriter.create(loc, lowerBoundIndex, upperBoundIndex, - stepIndex, adaptor.getIterArgsInit()); + scf::ForOp::create(rewriter, loc, lowerBoundIndex, upperBoundIndex, + stepIndex, adaptor.getIterArgsInit()); SmallVector regionArgTypes; SmallVector regionArgLocs; @@ -249,8 +249,8 @@ class ConvertTorchPrimLoopForLikeOp : public OpConversionPattern { for (const auto &barg : enumerate(op.getRegion().front().getArguments())) { Value to = block->getArgument(barg.index()); if (isa(to.getType())) - to = - rewriter.create(loc, rewriter.getI64Type(), to); + to = arith::IndexCastOp::create(rewriter, loc, rewriter.getI64Type(), + to); Type 
targetType = to.getType(); Value torchArg = to; @@ -298,7 +298,8 @@ class ConvertTorchPrimLoopForLikeOp : public OpConversionPattern { op, "unsupported type of the operand"); loopConditionIterArgs.push_back(arg); } - rewriter.create(scfForOp.getLoc(), loopConditionIterArgs); + scf::YieldOp::create(rewriter, scfForOp.getLoc(), + loopConditionIterArgs); } else { operation.moveBefore(block, block->end()); } diff --git a/lib/Conversion/TorchToStablehlo/Basic.cpp b/lib/Conversion/TorchToStablehlo/Basic.cpp index 09f3b12de494..a22e6658a2ac 100644 --- a/lib/Conversion/TorchToStablehlo/Basic.cpp +++ b/lib/Conversion/TorchToStablehlo/Basic.cpp @@ -80,8 +80,8 @@ static FailureOr getMaxValueOfDtype(Operation *op, Type elementType, constType, APFloat::getInf(cast(elementType).getFloatSemantics(), /*negative=*/false)); - return rewriter - .create(op->getLoc(), constType, constAttr) + return stablehlo::ConstantOp::create(rewriter, op->getLoc(), constType, + constAttr) .getResult(); } if (isa(elementType)) { @@ -94,8 +94,8 @@ static FailureOr getMaxValueOfDtype(Operation *op, Type elementType, constAttr = SplatElementsAttr::get( constType, APInt::getSignedMaxValue(integerType.getWidth())); } - return rewriter - .create(op->getLoc(), constType, constAttr) + return stablehlo::ConstantOp::create(rewriter, op->getLoc(), constType, + constAttr) .getResult(); } return failure(); @@ -109,8 +109,8 @@ static FailureOr getMinValueOfDtype(Operation *op, Type elementType, constType, APFloat::getInf(cast(elementType).getFloatSemantics(), /*negative=*/true)); - return rewriter - .create(op->getLoc(), constType, constAttr) + return stablehlo::ConstantOp::create(rewriter, op->getLoc(), constType, + constAttr) .getResult(); } if (isa(elementType)) { @@ -123,8 +123,8 @@ static FailureOr getMinValueOfDtype(Operation *op, Type elementType, constAttr = SplatElementsAttr::get( constType, APInt::getSignedMinValue(integerType.getWidth())); } - return rewriter - .create(op->getLoc(), constType, 
constAttr) + return stablehlo::ConstantOp::create(rewriter, op->getLoc(), constType, + constAttr) .getResult(); } return failure(); @@ -287,22 +287,22 @@ class ConvertAtenTensorToScalarLikeOp : public OpConversionPattern { Type inputDtype = cast(op.getA().getType()).getDtype(); Value constantOne = - rewriter.create(loc, rewriter.getI64IntegerAttr(1)); + arith::ConstantOp::create(rewriter, loc, rewriter.getI64IntegerAttr(1)); for (int64_t i = 0; i < inputRank; i++) checkDimEqualHelper(rewriter, loc, inputSizes[i], constantOne); // handle unsigned interger if (inputType.getElementType().isUnsignedInteger()) { - input = rewriter.create( - loc, input, + input = stablehlo::ConvertOp::create( + rewriter, loc, input, rewriter.getIntegerType( inputType.getElementType().getIntOrFloatBitWidth())); } Value constantZero = - rewriter.create(loc, rewriter.getIndexAttr(0)); + arith::ConstantOp::create(rewriter, loc, rewriter.getIndexAttr(0)); SmallVector indices(inputRank, constantZero); - Value result = rewriter.create(loc, input, indices); + Value result = tensor::ExtractOp::create(rewriter, loc, input, indices); Type resultType = this->getTypeConverter()->convertType(op->getResult(0).getType()); rewriter.replaceOp( @@ -390,8 +390,8 @@ class ConvertAtenAddSubOp : public OpConversionPattern { Value alpha = hlo::scalarToStablehloTensor(rewriter, op, adaptor.getAlpha(), outElemTy); DenseI64ArrayAttr bcastDimensions; - rhs = rewriter.create(op->getLoc(), rhs, alpha, - bcastDimensions); + rhs = chlo::BroadcastMulOp::create(rewriter, op->getLoc(), rhs, alpha, + bcastDimensions); } DenseI64ArrayAttr bcastDimensions; @@ -443,7 +443,7 @@ class ConvertAtenMulDivOp : public OpConversionPattern { rhs = hlo::promoteType(rewriter, op.getLoc(), rhs, outElemTy); auto loc = op.getLoc(); Value result = - rewriter.create(loc, outType, lhs, rhs, bcastDimensions); + ChloOpT::create(rewriter, loc, outType, lhs, rhs, bcastDimensions); if constexpr (!std::is_same() && !std::is_same()) { @@ -467,26 
+467,27 @@ class ConvertAtenMulDivOp : public OpConversionPattern { if (roundingMode == "trunc" && isa(outElemTy)) { // "trunc" - rounds the results of the division towards zero. Equivalent // to C-style integer division. - auto sign = rewriter.create(loc, result); - auto abs = rewriter.create(loc, result); - auto floor = rewriter.create(loc, abs); - result = rewriter.create(loc, sign, floor).getResult(); + auto sign = stablehlo::SignOp::create(rewriter, loc, result); + auto abs = stablehlo::AbsOp::create(rewriter, loc, result); + auto floor = stablehlo::FloorOp::create(rewriter, loc, abs); + result = stablehlo::MulOp::create(rewriter, loc, sign, floor).getResult(); } if (roundingMode == "floor") { // "floor" - rounds the results of the division down. Equivalent to // floor division in Python (the // operator) if (isa(outElemTy)) - result = rewriter.create(loc, result).getResult(); + result = stablehlo::FloorOp::create(rewriter, loc, result).getResult(); else if (!outElemTy.isUnsignedInteger()) { Type defaultIntToFloatType = rewriter.getF64Type(); lhs = hlo::promoteType(rewriter, op.getLoc(), lhs, defaultIntToFloatType); rhs = hlo::promoteType(rewriter, op.getLoc(), rhs, defaultIntToFloatType); - result = rewriter.create( - loc, outType.cloneWith(outType.getShape(), defaultIntToFloatType), - lhs, rhs, bcastDimensions); - result = rewriter.create(loc, result).getResult(); + result = ChloOpT::create( + rewriter, loc, + outType.cloneWith(outType.getShape(), defaultIntToFloatType), lhs, + rhs, bcastDimensions); + result = stablehlo::FloorOp::create(rewriter, loc, result).getResult(); result = hlo::promoteType(rewriter, op.getLoc(), result, outType.getElementType()); } @@ -721,18 +722,18 @@ LogicalResult ConvertAtenOp::matchAndRewrite( dimInt = toPositiveDim(dimInt, selfType.getRank()); if (!isValidDim(dimInt, selfType.getRank())) return rewriter.notifyMatchFailure(op, "dim is statically invalid"); - dim = rewriter.create(op.getLoc(), dimInt); + dim = 
arith::ConstantIndexOp::create(rewriter, op.getLoc(), dimInt); } else { - Value inputRank = rewriter.create( - op.getLoc(), rewriter.getI64IntegerAttr(selfType.getRank())); + Value inputRank = arith::ConstantOp::create( + rewriter, op.getLoc(), rewriter.getI64IntegerAttr(selfType.getRank())); dim = toPositiveDimDynamic(rewriter, op.getLoc(), adaptor.getDim(), inputRank); - dim = rewriter.create(op.getLoc(), - rewriter.getIndexType(), dim); + dim = arith::IndexCastOp::create(rewriter, op.getLoc(), + rewriter.getIndexType(), dim); } - auto dimSize = rewriter.create( - op.getLoc(), rewriter.getIndexType(), adaptor.getSelf(), dim); + auto dimSize = tensor::DimOp::create( + rewriter, op.getLoc(), rewriter.getIndexType(), adaptor.getSelf(), dim); rewriter.replaceOpWithNewOp( op, getTypeConverter()->convertType(op.getType()), dimSize); @@ -800,31 +801,31 @@ LogicalResult ConvertAtenOp::matchAndRewrite( int64_t dInt; if (i >= leadingRank && matchPattern(dValue, m_TorchConstantInt(&dInt)) && dInt == -1) { - newD = rewriter.create(op->getLoc(), self, - i - leadingRank); + newD = mlir::tensor::DimOp::create(rewriter, op->getLoc(), self, + i - leadingRank); } else { - dValue = rewriter.create(op->getLoc(), - dValue); - newD = rewriter.create( - op->getLoc(), rewriter.getIndexType(), dValue); + dValue = torch::TorchConversion::ToI64Op::create(rewriter, op->getLoc(), + dValue); + newD = mlir::arith::IndexCastOp::create(rewriter, op->getLoc(), + rewriter.getIndexType(), dValue); } bcastShapeVec.push_back(newD); } if (options.dimSizeIndexBits == 32) { for (auto &dsize : bcastShapeVec) { - auto dsizeI64 = rewriter.create( - op->getLoc(), rewriter.getI64Type(), dsize); - dsize = rewriter.create(op->getLoc(), - rewriter.getI32Type(), dsizeI64); + auto dsizeI64 = mlir::arith::IndexCastOp::create( + rewriter, op->getLoc(), rewriter.getI64Type(), dsize); + dsize = arith::TruncIOp::create(rewriter, op->getLoc(), + rewriter.getI32Type(), dsizeI64); } } if (bcastShapeVec.size() == 0) { 
rewriter.replaceOpWithNewOp(op, outType, self); } else { - Value bcastShapeTensor = rewriter.create( - op->getLoc(), ValueRange{bcastShapeVec}); + Value bcastShapeTensor = mlir::tensor::FromElementsOp::create( + rewriter, op->getLoc(), ValueRange{bcastShapeVec}); auto dimensionNumbers = llvm::to_vector<4>(llvm::seq(leadingRank, totalRank)); rewriter.replaceOpWithNewOp( @@ -1001,7 +1002,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( Type inputDtype = cast(op.getA().getType()).getDtype(); Type resultType = this->getTypeConverter()->convertType(op->getResult(0).getType()); - auto result = rewriter.create(loc, adaptor.getA()); + auto result = tensor::ExtractOp::create(rewriter, loc, adaptor.getA()); rewriter.replaceOp( op, convertScalarToDtype(rewriter, loc, result, resultType, inputDtype)); @@ -1078,24 +1079,25 @@ LogicalResult ConvertAtenOp::matchAndRewrite( Value t = hlo::getConstantLike(rewriter, loc, 0.044715, input); // x * 0.5 - auto inputMulHalf = rewriter.create(loc, input, half); + auto inputMulHalf = stablehlo::MulOp::create(rewriter, loc, input, half); if (approximate == "none") { - auto rsqrtTwo = rewriter.create(loc, two); - auto erfElement = rewriter.create(loc, input, rsqrtTwo); - auto erf = rewriter.create(loc, erfElement); - auto erfAdd = rewriter.create(loc, erf, one); + auto rsqrtTwo = stablehlo::RsqrtOp::create(rewriter, loc, two); + auto erfElement = stablehlo::MulOp::create(rewriter, loc, input, rsqrtTwo); + auto erf = chlo::ErfOp::create(rewriter, loc, erfElement); + auto erfAdd = stablehlo::AddOp::create(rewriter, loc, erf, one); rewriter.replaceOpWithNewOp(op, erfAdd, inputMulHalf); return success(); } else { - auto sqrtTwoPi = rewriter.create(loc, twoDivPi); + auto sqrtTwoPi = stablehlo::SqrtOp::create(rewriter, loc, twoDivPi); // x^3 - auto powThree = rewriter.create(loc, input, three); + auto powThree = stablehlo::PowOp::create(rewriter, loc, input, three); // x + 0.044715 * x^3 - auto add = rewriter.create( - loc, input, 
rewriter.create(loc, t, powThree)); - auto tanh = rewriter.create( - loc, rewriter.create(loc, sqrtTwoPi, add)); - auto tanhAdd = rewriter.create(loc, tanh, one); + auto add = stablehlo::AddOp::create( + rewriter, loc, input, + stablehlo::MulOp::create(rewriter, loc, t, powThree)); + auto tanh = stablehlo::TanhOp::create( + rewriter, loc, stablehlo::MulOp::create(rewriter, loc, sqrtTwoPi, add)); + auto tanhAdd = stablehlo::AddOp::create(rewriter, loc, tanh, one); rewriter.replaceOpWithNewOp(op, tanhAdd, inputMulHalf); return success(); } @@ -1116,8 +1118,8 @@ LogicalResult ConvertAtenOp::matchAndRewrite( hlo::promoteType(rewriter, op.getLoc(), input, outTy.getElementType()); auto two = hlo::getConstantLike(rewriter, op.getLoc(), 2.0, input); - auto log2Op = rewriter.create(op.getLoc(), two); - auto logInputOp = rewriter.create(op.getLoc(), input); + auto log2Op = stablehlo::LogOp::create(rewriter, op.getLoc(), two); + auto logInputOp = stablehlo::LogOp::create(rewriter, op.getLoc(), input); rewriter.replaceOpWithNewOp(op, outTy, logInputOp, log2Op); return success(); @@ -1139,8 +1141,8 @@ LogicalResult ConvertAtenOp::matchAndRewrite( hlo::promoteType(rewriter, op.getLoc(), input, outTy.getElementType()); auto ten = hlo::getConstantLike(rewriter, op.getLoc(), 10.0, input); - auto log10Op = rewriter.create(op.getLoc(), ten); - auto logInputOp = rewriter.create(op.getLoc(), input); + auto log10Op = stablehlo::LogOp::create(rewriter, op.getLoc(), ten); + auto logInputOp = stablehlo::LogOp::create(rewriter, op.getLoc(), input); rewriter.replaceOpWithNewOp(op, outTy, logInputOp, log10Op); return success(); @@ -1172,17 +1174,17 @@ LogicalResult ConvertAtenOp::matchAndRewrite( selfTy.getElementType()); Value oneEpsTensor = hlo::getConstantLike(rewriter, loc, 1.0, epsTensor); auto max = - rewriter.create(loc, oneEpsTensor, epsTensor); - newSelf = rewriter.create(loc, epsTensor, self, max); + stablehlo::SubtractOp::create(rewriter, loc, oneEpsTensor, epsTensor); + newSelf = 
stablehlo::ClampOp::create(rewriter, loc, epsTensor, self, max); } else { newSelf = self; } Value one = hlo::getConstantLike(rewriter, loc, 1.0, self); - Value zi1 = rewriter.create(loc, one, newSelf); - Value newZi = rewriter.create(loc, newSelf, zi1); + Value zi1 = stablehlo::SubtractOp::create(rewriter, loc, one, newSelf); + Value newZi = stablehlo::DivOp::create(rewriter, loc, newSelf, zi1); - Value log = rewriter.create(loc, outTy, newZi); + Value log = stablehlo::LogOp::create(rewriter, loc, outTy, newZi); rewriter.replaceOp(op, log); @@ -1229,17 +1231,17 @@ LogicalResult ConvertAtenOp::matchAndRewrite( auto inputElemTy = cast(inputTy.getElementType()); Value channelDim = - rewriter.create(op->getLoc(), input, feature_index); + tensor::DimOp::create(rewriter, op->getLoc(), input, feature_index); if (options.dimSizeIndexBits == 32) { - auto channelDimI64 = rewriter.create( - op->getLoc(), rewriter.getI64Type(), channelDim); - channelDim = rewriter.create( - op->getLoc(), rewriter.getI32Type(), channelDimI64); + auto channelDimI64 = mlir::arith::IndexCastOp::create( + rewriter, op->getLoc(), rewriter.getI64Type(), channelDim); + channelDim = arith::TruncIOp::create(rewriter, op->getLoc(), + rewriter.getI32Type(), channelDimI64); } - Value channelShape = rewriter.create( - op->getLoc(), ValueRange{channelDim}); + Value channelShape = tensor::FromElementsOp::create(rewriter, op->getLoc(), + ValueRange{channelDim}); if (failed(checkNotNone(rewriter, op, weight))) { weight = hlo::getConstantOfShape( rewriter, op->getLoc(), APFloat(inputElemTy.getFloatSemantics(), 1), @@ -1316,23 +1318,21 @@ LogicalResult ConvertAtenOp::matchAndRewrite( input = hlo::promoteType(rewriter, op.getLoc(), input, computeType); weight = hlo::promoteType(rewriter, op.getLoc(), weight, computeType); bias = hlo::promoteType(rewriter, op.getLoc(), bias, computeType); - auto batchNormTrainingResult = - rewriter.create( - op.getLoc(), - RankedTensorType::get(inputTy.getShape(), computeType), - 
RankedTensorType::get(weightTy.getShape(), computeType), - RankedTensorType::get(weightTy.getShape(), computeType), input, - weight, bias, rewriter.getF32FloatAttr(eps), - rewriter.getI64IntegerAttr(feature_index)); + auto batchNormTrainingResult = stablehlo::BatchNormTrainingOp::create( + rewriter, op.getLoc(), + RankedTensorType::get(inputTy.getShape(), computeType), + RankedTensorType::get(weightTy.getShape(), computeType), + RankedTensorType::get(weightTy.getShape(), computeType), input, + weight, bias, rewriter.getF32FloatAttr(eps), + rewriter.getI64IntegerAttr(feature_index)); output = hlo::promoteType(rewriter, op.getLoc(), batchNormTrainingResult.getResult(0), outputTy.getElementType()); } else { - auto batchNormTrainingResult = - rewriter.create( - op.getLoc(), outputTy, weightTy, weightTy, input, weight, bias, - rewriter.getF32FloatAttr(eps), - rewriter.getI64IntegerAttr(feature_index)); + auto batchNormTrainingResult = stablehlo::BatchNormTrainingOp::create( + rewriter, op.getLoc(), outputTy, weightTy, weightTy, input, weight, + bias, rewriter.getF32FloatAttr(eps), + rewriter.getI64IntegerAttr(feature_index)); output = batchNormTrainingResult.getResult(0); } rewriter.replaceOp(op, output); @@ -1347,7 +1347,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( // Feature counts must match among operands of // stablehlo::BatchNormInferenceOp. Value inputCasted = - rewriter.create(op.getLoc(), castTy, input); + tensor::CastOp::create(rewriter, op.getLoc(), castTy, input); Value output; // supported mixed types, like input type is fp16 and weight type is fp32. 
@@ -1364,17 +1364,17 @@ LogicalResult ConvertAtenOp::matchAndRewrite( hlo::promoteType(rewriter, op.getLoc(), runningMean, computeType); runningVar = hlo::promoteType(rewriter, op.getLoc(), runningVar, computeType); - Value bnResult = rewriter.create( - op.getLoc(), RankedTensorType::get(inputTy.getShape(), computeType), - input, weight, bias, runningMean, runningVar, - rewriter.getF32FloatAttr(eps), + Value bnResult = stablehlo::BatchNormInferenceOp::create( + rewriter, op.getLoc(), + RankedTensorType::get(inputTy.getShape(), computeType), input, weight, + bias, runningMean, runningVar, rewriter.getF32FloatAttr(eps), rewriter.getI64IntegerAttr(feature_index)); output = hlo::promoteType(rewriter, op.getLoc(), bnResult, outputTy.getElementType()); } else { - output = rewriter.create( - op.getLoc(), inputCasted.getType(), inputCasted, weight, bias, - runningMean, runningVar, + output = stablehlo::BatchNormInferenceOp::create( + rewriter, op.getLoc(), inputCasted.getType(), inputCasted, weight, + bias, runningMean, runningVar, // 'epsilon' must satisfy constraint: 32-bit float attribute. 
rewriter.getF32FloatAttr(eps), rewriter.getI64IntegerAttr(feature_index)); @@ -1459,8 +1459,8 @@ LogicalResult ConvertAtenOp::matchAndRewrite( meanOrVarStablehloOutShape, inputTy.getElementType()); // Reshape input - auto stablehloInput = rewriter.create( - op->getLoc(), stablehloBatchNormOutTy, input, + auto stablehloInput = stablehlo::DynamicReshapeOp::create( + rewriter, op->getLoc(), stablehloBatchNormOutTy, input, hlo::getConstTensor(rewriter, op, llvm::ArrayRef(inputFlattenShape), {static_cast(inputFlattenShape.size())}) .value()); @@ -1478,18 +1478,17 @@ LogicalResult ConvertAtenOp::matchAndRewrite( auto oneOrZeroConstType = RankedTensorType::get({numFeatureDimSize}, inputTy.getElementType()); - Value scale = rewriter.create( - op->getLoc(), oneOrZeroConstType, + Value scale = stablehlo::ConstantOp::create( + rewriter, op->getLoc(), oneOrZeroConstType, DenseElementsAttr::get(oneOrZeroConstType, oneConstVec)); - Value offset = rewriter.create( - op->getLoc(), oneOrZeroConstType, + Value offset = stablehlo::ConstantOp::create( + rewriter, op->getLoc(), oneOrZeroConstType, DenseElementsAttr::get(oneOrZeroConstType, zeroConstVec)); - auto batchNormTrainingResult = - rewriter.create( - op->getLoc(), stablehloBatchNormOutTy, - stablehloBathNormOutMeanOrVarTy, stablehloBathNormOutMeanOrVarTy, - stablehloInput, scale, offset, rewriter.getF32FloatAttr(eps), - rewriter.getI64IntegerAttr(1)); + auto batchNormTrainingResult = stablehlo::BatchNormTrainingOp::create( + rewriter, op->getLoc(), stablehloBatchNormOutTy, + stablehloBathNormOutMeanOrVarTy, stablehloBathNormOutMeanOrVarTy, + stablehloInput, scale, offset, rewriter.getF32FloatAttr(eps), + rewriter.getI64IntegerAttr(1)); // Reshape back auto outputTy = @@ -1497,19 +1496,21 @@ LogicalResult ConvertAtenOp::matchAndRewrite( auto outputMeanOrVarTy = cast(getTypeConverter()->convertType(op.getType(1))); - auto output = rewriter.create( - op->getLoc(), outputTy, batchNormTrainingResult.getResult(0), + auto output = 
stablehlo::DynamicReshapeOp::create( + rewriter, op->getLoc(), outputTy, batchNormTrainingResult.getResult(0), hlo::getConstTensor(rewriter, op, outputTy.getShape(), {static_cast(outputTy.getShape().size())}) .value()); - auto mean = rewriter.create( - op->getLoc(), outputMeanOrVarTy, batchNormTrainingResult.getResult(1), + auto mean = stablehlo::DynamicReshapeOp::create( + rewriter, op->getLoc(), outputMeanOrVarTy, + batchNormTrainingResult.getResult(1), hlo::getConstTensor( rewriter, op, outputMeanOrVarTy.getShape(), {static_cast(outputMeanOrVarTy.getShape().size())}) .value()); - auto var = rewriter.create( - op->getLoc(), outputMeanOrVarTy, batchNormTrainingResult.getResult(2), + auto var = stablehlo::DynamicReshapeOp::create( + rewriter, op->getLoc(), outputMeanOrVarTy, + batchNormTrainingResult.getResult(2), hlo::getConstTensor( rewriter, op, outputMeanOrVarTy.getShape(), {static_cast(outputMeanOrVarTy.getShape().size())}) @@ -1521,9 +1522,9 @@ LogicalResult ConvertAtenOp::matchAndRewrite( auto bcastedBias = hlo::promoteAndBroadcast(rewriter, bias, outputTy, std::nullopt); auto outputMulWeight = - rewriter.create(op->getLoc(), output, bcastedWeight); - auto finalOuput = rewriter.create( - op->getLoc(), outputMulWeight, bcastedBias); + stablehlo::MulOp::create(rewriter, op->getLoc(), output, bcastedWeight); + auto finalOuput = stablehlo::AddOp::create(rewriter, op->getLoc(), + outputMulWeight, bcastedBias); rewriter.replaceOp(op, {finalOuput, mean, var}); return success(); } @@ -1573,12 +1574,12 @@ LogicalResult ConvertAtenOp::matchAndRewrite( Type intType = rewriter.getIntegerType(options.dimSizeIndexBits); auto loc = op->getLoc(); - Value numel = rewriter.create( - loc, rewriter.getIntegerAttr(intType, 1)); + Value numel = arith::ConstantOp::create(rewriter, loc, + rewriter.getIntegerAttr(intType, 1)); for (size_t d = 0; d < rank; ++d) { - Value dimSize = rewriter.create( - loc, intType, rewriter.create(loc, self, d)); - numel = rewriter.create(loc, numel, 
dimSize); + Value dimSize = arith::IndexCastOp::create( + rewriter, loc, intType, tensor::DimOp::create(rewriter, loc, self, d)); + numel = arith::MulIOp::create(rewriter, loc, numel, dimSize); } auto outTy = getTypeConverter()->convertType(op.getType()); @@ -1699,24 +1700,26 @@ LogicalResult ConvertAtenOp::matchAndRewrite( hlo::scalarToStablehloTensor(rewriter, op, adaptor.getStep(), dtype); // Get length of the 1-d output tensor - Value subOut = rewriter.create(loc, end, start); + Value subOut = stablehlo::SubtractOp::create(rewriter, loc, end, start); // promote div to f64 Type divType = RankedTensorType::get({}, rewriter.getF64Type()); - Value divOut = rewriter.create( - loc, rewriter.create(loc, divType, subOut), - rewriter.create(loc, divType, step)); + Value divOut = stablehlo::DivOp::create( + rewriter, loc, + stablehlo::ConvertOp::create(rewriter, loc, divType, subOut), + stablehlo::ConvertOp::create(rewriter, loc, divType, step)); // ceil to i64 - Value resultLength = rewriter.create( - loc, RankedTensorType::get({}, rewriter.getI64Type()), - rewriter.create(loc, divOut)); - resultLength = rewriter.create( - loc, RankedTensorType::get({1}, rewriter.getI64Type()), resultLength); + Value resultLength = stablehlo::ConvertOp::create( + rewriter, loc, RankedTensorType::get({}, rewriter.getI64Type()), + stablehlo::CeilOp::create(rewriter, loc, divOut)); + resultLength = stablehlo::ReshapeOp::create( + rewriter, loc, RankedTensorType::get({1}, rewriter.getI64Type()), + resultLength); Value window = - rewriter.create(loc, outType, resultLength, 0); + stablehlo::DynamicIotaOp::create(rewriter, loc, outType, resultLength, 0); DenseI64ArrayAttr broadcastDimensions; - Value mulOut = rewriter.create(loc, window, step, - broadcastDimensions); + Value mulOut = chlo::BroadcastMulOp::create(rewriter, loc, window, step, + broadcastDimensions); rewriter.replaceOpWithNewOp(op, mulOut, start, broadcastDimensions); return success(); @@ -1797,10 +1800,10 @@ LogicalResult 
ConvertAtenOp::matchAndRewrite( SmallVector strides(rank, 1); startIndices[dim] = 1; limitIndices[dim] = padInts[0] + 1; - left = rewriter.create(loc, self, startIndices, - limitIndices, strides); - left = rewriter.create(loc, left, - ArrayRef({dim})); + left = stablehlo::SliceOp::create(rewriter, loc, self, startIndices, + limitIndices, strides); + left = stablehlo::ReverseOp::create(rewriter, loc, left, + ArrayRef({dim})); } Value right; { @@ -1810,13 +1813,13 @@ LogicalResult ConvertAtenOp::matchAndRewrite( SmallVector strides(rank, 1); startIndices[dim] = selfTy.getDimSize(dim) - 1 - padInts[1]; limitIndices[dim] = selfTy.getDimSize(dim) - 1; - right = rewriter.create(loc, self, startIndices, - limitIndices, strides); - right = rewriter.create(loc, right, - ArrayRef({dim})); + right = stablehlo::SliceOp::create(rewriter, loc, self, startIndices, + limitIndices, strides); + right = stablehlo::ReverseOp::create(rewriter, loc, right, + ArrayRef({dim})); } - Value result = rewriter.create( - loc, ValueRange{left, self, right}, dim); + Value result = stablehlo::ConcatenateOp::create( + rewriter, loc, ValueRange{left, self, right}, dim); rewriter.replaceOp(op, result); return success(); } @@ -1849,24 +1852,24 @@ LogicalResult ConvertAtenOp::matchAndRewrite( // Compute Value kBeta0 = - rewriter.create(loc, outType, kAlpha, cstAlpha0); - Value kBeta = rewriter.create(loc, outType, kBeta0, half); - Value erfArg = rewriter.create(loc, outType, kAlpha, - adaptor.getSelf()); - Value erf = rewriter.create(loc, outType, erfArg); - Value erfAdd = rewriter.create(loc, outType, erf, one); - Value cdf = rewriter.create(loc, outType, erfAdd, half); - Value inputSquared = rewriter.create( - loc, outType, adaptor.getSelf(), adaptor.getSelf()); + stablehlo::MulOp::create(rewriter, loc, outType, kAlpha, cstAlpha0); + Value kBeta = stablehlo::MulOp::create(rewriter, loc, outType, kBeta0, half); + Value erfArg = stablehlo::MulOp::create(rewriter, loc, outType, kAlpha, + 
adaptor.getSelf()); + Value erf = mlir::chlo::ErfOp::create(rewriter, loc, outType, erfArg); + Value erfAdd = stablehlo::AddOp::create(rewriter, loc, outType, erf, one); + Value cdf = stablehlo::MulOp::create(rewriter, loc, outType, erfAdd, half); + Value inputSquared = stablehlo::MulOp::create( + rewriter, loc, outType, adaptor.getSelf(), adaptor.getSelf()); Value negHalfInputSquared = - rewriter.create(loc, outType, inputSquared, negHalf); + stablehlo::MulOp::create(rewriter, loc, outType, inputSquared, negHalf); Value expRes = - rewriter.create(loc, outType, negHalfInputSquared); - Value pdf = rewriter.create(loc, outType, kBeta, expRes); + stablehlo::ExpOp::create(rewriter, loc, outType, negHalfInputSquared); + Value pdf = stablehlo::MulOp::create(rewriter, loc, outType, kBeta, expRes); Value pdfTimesInput = - rewriter.create(loc, outType, pdf, adaptor.getSelf()); + stablehlo::MulOp::create(rewriter, loc, outType, pdf, adaptor.getSelf()); Value pdfTimesInputAddCdf = - rewriter.create(loc, outType, pdfTimesInput, cdf); + stablehlo::AddOp::create(rewriter, loc, outType, pdfTimesInput, cdf); rewriter.replaceOpWithNewOp( op, outType, adaptor.getGradOutput(), pdfTimesInputAddCdf); return success(); @@ -1956,8 +1959,8 @@ LogicalResult ConvertAtenOp::matchAndRewrite( } // Create an uninitialized tensor of `resultSize` shape. 
- Value initTensor = rewriter.create( - loc, getAsOpFoldResult(resultSizeIndex), resultElementType); + Value initTensor = tensor::EmptyOp::create( + rewriter, loc, getAsOpFoldResult(resultSizeIndex), resultElementType); rewriter.replaceOpWithNewOp(op, resultType, initTensor); return success(); } @@ -1996,9 +1999,9 @@ LogicalResult ConvertAtenOp::matchAndRewrite( Value scalarTensor = hlo::scalarToStablehloTensor(rewriter, op, adaptor.getValue(), dtype); Value shapeTensor = - rewriter.create(op->getLoc(), adaptor.getSelf()); - Value bcastScalar = rewriter.create( - op->getLoc(), outType, scalarTensor, shapeTensor, + shape::ShapeOfOp::create(rewriter, op->getLoc(), adaptor.getSelf()); + Value bcastScalar = stablehlo::DynamicBroadcastInDimOp::create( + rewriter, op->getLoc(), outType, scalarTensor, shapeTensor, rewriter.getDenseI64ArrayAttr({})); rewriter.replaceOp(op, bcastScalar); return success(); @@ -2046,16 +2049,16 @@ LogicalResult ConvertAtenOp::matchAndRewrite( resultType.getElementType()); stablehlo::MulOp mul; - auto div = rewriter.create(loc, lhs, rhs); + auto div = stablehlo::DivOp::create(rewriter, loc, lhs, rhs); if (isa(resultType.getElementType())) { // rounding mode is trunc - auto sign = rewriter.create(loc, div); - auto abs = rewriter.create(loc, div); - auto floor = rewriter.create(loc, abs); - auto trunc = rewriter.create(loc, sign, floor); - mul = rewriter.create(loc, trunc, rhs); + auto sign = stablehlo::SignOp::create(rewriter, loc, div); + auto abs = stablehlo::AbsOp::create(rewriter, loc, div); + auto floor = stablehlo::FloorOp::create(rewriter, loc, abs); + auto trunc = stablehlo::MulOp::create(rewriter, loc, sign, floor); + mul = stablehlo::MulOp::create(rewriter, loc, trunc, rhs); } else { - mul = rewriter.create(loc, div, rhs); + mul = stablehlo::MulOp::create(rewriter, loc, div, rhs); } rewriter.replaceOpWithNewOp(op, lhs, mul); return success(); @@ -2111,34 +2114,34 @@ LogicalResult ConvertAtenOp::matchAndRewrite( auto iotaTy = 
RankedTensorType::get( {selfShape[selfRank - 2], selfShape[selfRank - 1]}, iotaElementTy); Value colIdxTensor = - rewriter.create(loc, iotaTy, 1).getResult(); + stablehlo::IotaOp::create(rewriter, loc, iotaTy, 1).getResult(); Value rowIdxTensor = - rewriter.create(loc, iotaTy, 0).getResult(); + stablehlo::IotaOp::create(rewriter, loc, iotaTy, 0).getResult(); Value diagonal = adaptor.getDiagonal(); Value diagonalTensor = - rewriter.create(loc, diagonal).getResult(); + tensor::FromElementsOp::create(rewriter, loc, diagonal).getResult(); auto bcastDimensions = rewriter.getDenseI64ArrayAttr({1}); - Value shiftedRowIdxTensor = rewriter.create( - loc, rowIdxTensor, diagonalTensor, bcastDimensions); + Value shiftedRowIdxTensor = chlo::BroadcastAddOp::create( + rewriter, loc, rowIdxTensor, diagonalTensor, bcastDimensions); auto cmpDirectionAttr = stablehlo::ComparisonDirectionAttr::get( rewriter.getContext(), stablehlo::ComparisonDirection::LE); auto cmpTypeAttr = stablehlo::ComparisonTypeAttr::get( rewriter.getContext(), stablehlo::ComparisonType::SIGNED); auto cmpTy = iotaTy.clone(rewriter.getI1Type()); - Value cmpRes = rewriter.create( - loc, cmpTy, colIdxTensor, shiftedRowIdxTensor, cmpDirectionAttr, - cmpTypeAttr); + Value cmpRes = stablehlo::CompareOp::create(rewriter, loc, cmpTy, + colIdxTensor, shiftedRowIdxTensor, + cmpDirectionAttr, cmpTypeAttr); auto resTy = cast(getTypeConverter()->convertType(op.getType())); auto bcastTy = resTy.clone(rewriter.getI1Type()); auto bcastAttr = rewriter.getDenseI64ArrayAttr({selfRank - 2, selfRank - 1}); - Value bcastedCmpRes = rewriter.create( - loc, bcastTy, cmpRes, bcastAttr); + Value bcastedCmpRes = stablehlo::BroadcastInDimOp::create( + rewriter, loc, bcastTy, cmpRes, bcastAttr); auto resElemTy = resTy.getElementType(); Value zeroTensor; @@ -2146,12 +2149,12 @@ LogicalResult ConvertAtenOp::matchAndRewrite( auto constAttr = SplatElementsAttr::get( resTy, llvm::APFloat::getZero( cast(resElemTy).getFloatSemantics(), false)); - 
zeroTensor = rewriter.create(loc, resTy, constAttr); + zeroTensor = stablehlo::ConstantOp::create(rewriter, loc, resTy, constAttr); } else if (isa(resElemTy)) { auto constAttr = SplatElementsAttr::get( resTy, llvm::APInt::getZero(cast(resElemTy).getWidth())); - zeroTensor = rewriter.create(loc, resTy, constAttr); + zeroTensor = stablehlo::ConstantOp::create(rewriter, loc, resTy, constAttr); } else { return op.emitError("element type is not float or integer"); } diff --git a/lib/Conversion/TorchToStablehlo/GatherScatter.cpp b/lib/Conversion/TorchToStablehlo/GatherScatter.cpp index c7a67abebab5..8ebb7050b124 100644 --- a/lib/Conversion/TorchToStablehlo/GatherScatter.cpp +++ b/lib/Conversion/TorchToStablehlo/GatherScatter.cpp @@ -42,14 +42,14 @@ static Value createInitialValueForGatherScatterOp(Operation *op, constType, {APFloat::getZero( cast(elementTy).getFloatSemantics(), /*negative=*/false)}); - return rewriter.create(op->getLoc(), constType, - constAttr); + return stablehlo::ConstantOp::create(rewriter, op->getLoc(), constType, + constAttr); } else if (isa(elementTy) && elementTy.getIntOrFloatBitWidth() != 8) { auto constAttr = DenseElementsAttr::get( constType, {APInt::getZero(elementTy.getIntOrFloatBitWidth())}); - return rewriter.create(op->getLoc(), constType, - constAttr); + return stablehlo::ConstantOp::create(rewriter, op->getLoc(), constType, + constAttr); } } @@ -63,8 +63,8 @@ Value gatherTensorAlongSingleAxis(PatternRewriter &rewriter, Operation *op, size_t dimSizeIndexBits) { auto loc = op->getLoc(); Type intType = rewriter.getIntegerType(dimSizeIndexBits); - Value one = rewriter.create( - loc, rewriter.getIntegerAttr(intType, 1)); + Value one = arith::ConstantOp::create(rewriter, loc, + rewriter.getIntegerAttr(intType, 1)); // sliceSizes auto inputRankTy = dyn_cast(input.getType()); @@ -75,12 +75,13 @@ Value gatherTensorAlongSingleAxis(PatternRewriter &rewriter, Operation *op, if (r == axis) { sliceSizes.push_back(one); } else { - 
sliceSizes.push_back(rewriter.create( - loc, intType, rewriter.create(loc, input, r))); + sliceSizes.push_back(arith::IndexCastOp::create( + rewriter, loc, intType, + tensor::DimOp::create(rewriter, loc, input, r))); } } auto sliceSizesTensor = - rewriter.create(loc, sliceSizes); + tensor::FromElementsOp::create(rewriter, loc, sliceSizes); // offsetDims SmallVector offsetDims; @@ -123,9 +124,8 @@ Value gatherTensorAlongSingleAxis(PatternRewriter &rewriter, Operation *op, // create output tensor type auto outputTy = RankedTensorType::get(outputShape, inputRankTy.getElementType()); - return rewriter - .create(loc, outputTy, input, indices, - sliceSizesTensor, dimsAttr) + return stablehlo::DynamicGatherOp::create(rewriter, loc, outputTy, input, + indices, sliceSizesTensor, dimsAttr) .getResult(); } @@ -139,8 +139,8 @@ LogicalResult prepareArgumentsForSlicingOp(OpTy op, OpAdaptor adaptor, auto input = adaptor.getSelf(); RankedTensorType inputType = cast(input.getType()); - Value zero = rewriter.create(loc, 0); - Value one = rewriter.create(loc, 1); + Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0); + Value one = arith::ConstantIndexOp::create(rewriter, loc, 1); int64_t dim; if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim))) @@ -176,24 +176,25 @@ LogicalResult prepareArgumentsForSlicingOp(OpTy op, OpAdaptor adaptor, dimSize, dimSize); // end >= start ? 
end : start - Value endSgeStart = rewriter.create( - loc, arith::CmpIPredicate::sge, end, start); - end = rewriter.create(loc, endSgeStart, end, start); - Value stepIndex = rewriter.create(loc, step); + Value endSgeStart = arith::CmpIOp::create( + rewriter, loc, arith::CmpIPredicate::sge, end, start); + end = arith::SelectOp::create(rewriter, loc, endSgeStart, end, start); + Value stepIndex = arith::ConstantIndexOp::create(rewriter, loc, step); // Slice logic: resultSize = floordiv(end - start + step - 1, step) resultShape = getTensorSizes(rewriter, loc, input); - Value len = rewriter.create(loc, end, start); - Value resultSize = rewriter.create(loc, len, stepIndex); - resultSize = rewriter.create(loc, resultSize, one); - resultSize = rewriter.create(loc, resultSize, stepIndex); + Value len = arith::SubIOp::create(rewriter, loc, end, start); + Value resultSize = arith::AddIOp::create(rewriter, loc, len, stepIndex); + resultSize = arith::SubIOp::create(rewriter, loc, resultSize, one); + resultSize = + arith::FloorDivSIOp::create(rewriter, loc, resultSize, stepIndex); resultShape[dim] = resultSize; strides.resize(inputType.getRank(), one); offsets.resize(inputType.getRank(), zero); offsets[dim] = start; - strides[dim] = rewriter.create(loc, strides[dim], stepIndex); + strides[dim] = arith::MulIOp::create(rewriter, loc, strides[dim], stepIndex); return success(); } } // namespace @@ -259,22 +260,23 @@ FailureOr broadcastAndConcatIndices(Operation *op, if (allIndexStaticShape) { bcastVal = hlo::promoteAndBroadcast(rewriter, indexTensor, bcastIndexType, std::nullopt); - bcastVal = rewriter.create(op->getLoc(), - reshapeType, bcastVal); + bcastVal = stablehlo::ReshapeOp::create(rewriter, op->getLoc(), + reshapeType, bcastVal); } else { bcastVal = hlo::promoteAndBroadcast(rewriter, indexTensor, bcastIndexType, bcastSizeTensor); auto bcastValShapeTensorVec = *hlo::getDimSizesOfTensor(rewriter, op, bcastVal, dimSizeIndexBits); - 
bcastValShapeTensorVec.push_back(rewriter.create( - op->getLoc(), rewriter.getIntegerAttr( - rewriter.getIntegerType(dimSizeIndexBits), 1))); - Value bcastValShapeTensor = rewriter - .create( - op->getLoc(), bcastValShapeTensorVec) - .getResult(); - bcastVal = rewriter.create( - op->getLoc(), reshapeType, bcastVal, bcastValShapeTensor); + bcastValShapeTensorVec.push_back(mlir::arith::ConstantOp::create( + rewriter, op->getLoc(), + rewriter.getIntegerAttr(rewriter.getIntegerType(dimSizeIndexBits), + 1))); + Value bcastValShapeTensor = + tensor::FromElementsOp::create(rewriter, op->getLoc(), + bcastValShapeTensorVec) + .getResult(); + bcastVal = stablehlo::DynamicReshapeOp::create( + rewriter, op->getLoc(), reshapeType, bcastVal, bcastValShapeTensor); } broadcastedIndices.push_back(bcastVal); } @@ -283,8 +285,8 @@ FailureOr broadcastAndConcatIndices(Operation *op, Value finalIndexTensor = broadcastedIndices[0]; if (broadcastedIndices.size() > 1) { RankedTensorType concatTy = RankedTensorType::get(concatShape, indexElemTy); - finalIndexTensor = rewriter.create( - op->getLoc(), concatTy, ValueRange(broadcastedIndices), + finalIndexTensor = stablehlo::ConcatenateOp::create( + rewriter, op->getLoc(), concatTy, ValueRange(broadcastedIndices), concatShape.size() - 1); } return finalIndexTensor; @@ -430,9 +432,9 @@ LogicalResult ConvertAtenOp::matchAndRewrite( if (!initValue) return failure(); - auto stablehloReduceOp = rewriter.create( - op.getLoc(), gatherOutput, initValue, rewriter.getDenseI64ArrayAttr({0}), - elementTy); + auto stablehloReduceOp = stablehlo::ReduceOp::create( + rewriter, op.getLoc(), gatherOutput, initValue, + rewriter.getDenseI64ArrayAttr({0}), elementTy); Region ®ion = stablehloReduceOp.getBody(); Block &block = region.emplaceBlock(); @@ -447,9 +449,10 @@ LogicalResult ConvertAtenOp::matchAndRewrite( { OpBuilder::InsertionGuard guard(rewriter); rewriter.setInsertionPointToStart(&block); - Value addResult = rewriter.create( - op->getLoc(), 
blockArgumentTy, *firstArgument, *secondArgument); - rewriter.create(op->getLoc(), addResult); + Value addResult = + stablehlo::AddOp::create(rewriter, op->getLoc(), blockArgumentTy, + *firstArgument, *secondArgument); + stablehlo::ReturnOp::create(rewriter, op->getLoc(), addResult); } auto outShapeInfo = hlo::getDimIndexOfTensor(rewriter, op, weight); @@ -458,13 +461,14 @@ LogicalResult ConvertAtenOp::matchAndRewrite( op, "failed to get dimension sizes of the input"); } auto outShapeVec = *outShapeInfo; - auto one = rewriter.create( - op->getLoc(), rewriter.getIntegerAttr(rewriter.getIndexType(), 1)); + auto one = mlir::arith::ConstantOp::create( + rewriter, op->getLoc(), + rewriter.getIntegerAttr(rewriter.getIndexType(), 1)); outShapeVec[0] = one; auto outShapeTensor = - rewriter.create(op->getLoc(), outShapeVec); - auto resultA = rewriter.create( - loc, getTypeConverter()->convertType(op.getType(0)), + mlir::tensor::FromElementsOp::create(rewriter, op->getLoc(), outShapeVec); + auto resultA = stablehlo::DynamicReshapeOp::create( + rewriter, loc, getTypeConverter()->convertType(op.getType(0)), stablehloReduceOp.getResult(0), outShapeTensor); RankedTensorType resultType = cast( @@ -554,12 +558,12 @@ LogicalResult ConvertAtenOp::matchAndRewrite( return rewriter.notifyMatchFailure( op, "failed to get dim sizes of `index` param"); } - auto one = rewriter.create( - loc, rewriter.getIntegerAttr(rewriter.getIndexType(), 1)); + auto one = arith::ConstantOp::create( + rewriter, loc, rewriter.getIntegerAttr(rewriter.getIndexType(), 1)); auto toConcatIndexShapeValueVec = *indexShapeInfo; toConcatIndexShapeValueVec.push_back(one); auto toConcatIndexShape = - rewriter.create(loc, toConcatIndexShapeValueVec); + tensor::FromElementsOp::create(rewriter, loc, toConcatIndexShapeValueVec); auto indexShape = indexType.getShape(); SmallVector toConcatIndexShapeVec(indexShape.begin(), @@ -571,16 +575,16 @@ LogicalResult ConvertAtenOp::matchAndRewrite( SmallVector toConcat; for 
(int64_t i = 0; i < inputType.getRank(); ++i) { if (i == dim) { - toConcat.push_back(rewriter.create( - loc, toConcatIndexType, index, toConcatIndexShape)); + toConcat.push_back(stablehlo::DynamicReshapeOp::create( + rewriter, loc, toConcatIndexType, index, toConcatIndexShape)); } else { - toConcat.push_back(rewriter.create( - loc, toConcatIndexType, toConcatIndexShape, + toConcat.push_back(stablehlo::DynamicIotaOp::create( + rewriter, loc, toConcatIndexType, toConcatIndexShape, rewriter.getI64IntegerAttr(i))); } } - auto gatherIndicies = rewriter.create( - loc, toConcat, static_cast(inputType.getRank())); + auto gatherIndicies = stablehlo::ConcatenateOp::create( + rewriter, loc, toConcat, static_cast(inputType.getRank())); SmallVector sliceSizes(inputType.getRank(), 1); int64_t indexVecDim = inputType.getRank(); @@ -674,7 +678,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( auto constAttr = DenseElementsAttr::get( RankedTensorType::get(shape, rewriter.getIntegerType(64)), indices); auto const_op = - rewriter.create(loc, constType, constAttr); + stablehlo::ConstantOp::create(rewriter, loc, constType, constAttr); Value scatterIndices = const_op.getResult(); SmallVector updateWindowDims; @@ -695,8 +699,9 @@ LogicalResult ConvertAtenOp::matchAndRewrite( /*indexVectorDim=*/1); Value src = adaptor.getSrc(); - auto scatterOp = rewriter.create( - loc, resultType, input, scatterIndices, src, scatterArgs, false, false); + auto scatterOp = stablehlo::ScatterOp::create(rewriter, loc, resultType, + input, scatterIndices, src, + scatterArgs, false, false); Block &block = scatterOp.getUpdateComputation().emplaceBlock(); auto blockArgumentType = @@ -709,7 +714,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( { OpBuilder::InsertionGuard guard(rewriter); rewriter.setInsertionPointToStart(&block); - rewriter.create(loc, *rhs); + stablehlo::ReturnOp::create(rewriter, loc, *rhs); } rewriter.replaceOp(op, scatterOp.getResults()); @@ -759,31 +764,31 @@ class ConvertAtenScatterOp : 
public ConvertAtenOp { // leading dimensions. PyTorch has guaranteed that src tensor size will not // be smaller than that of index tensor. REF: // https://pytorch.org/docs/stable/generated/torch.Tensor.scatter_.html#torch.Tensor.scatter_ - auto zero = rewriter.create( - loc, rewriter.getIntegerAttr(rewriter.getIndexType(), 0)); - auto one = rewriter.create( - loc, rewriter.getIntegerAttr(rewriter.getIndexType(), 1)); + auto zero = arith::ConstantOp::create( + rewriter, loc, rewriter.getIntegerAttr(rewriter.getIndexType(), 0)); + auto one = arith::ConstantOp::create( + rewriter, loc, rewriter.getIntegerAttr(rewriter.getIndexType(), 1)); SmallVector sliceIndicies(srcType.getRank(), zero); SmallVector sliceStrides(srcType.getRank(), one); auto sliceIndiciesValue = - rewriter.create(loc, sliceIndicies); + tensor::FromElementsOp::create(rewriter, loc, sliceIndicies); auto sliceStridesValue = - rewriter.create(loc, sliceStrides); + tensor::FromElementsOp::create(rewriter, loc, sliceStrides); auto sliceLimitIndiciesValue = - rewriter.create(loc, *indexShapeInfo); + tensor::FromElementsOp::create(rewriter, loc, *indexShapeInfo); auto newSrcType = RankedTensorType::get(indexType.getShape(), srcType.getElementType()); - src = rewriter.create( - loc, newSrcType, src, sliceIndiciesValue, sliceLimitIndiciesValue, - sliceStridesValue); + src = stablehlo::RealDynamicSliceOp::create( + rewriter, loc, newSrcType, src, sliceIndiciesValue, + sliceLimitIndiciesValue, sliceStridesValue); // generate scatter indicies for stablehlo::Scatter op. 
auto toConcatIndexShapeValueVec = *indexShapeInfo; toConcatIndexShapeValueVec.push_back(one); - auto toConcatIndexShape = rewriter.create( - loc, toConcatIndexShapeValueVec); + auto toConcatIndexShape = tensor::FromElementsOp::create( + rewriter, loc, toConcatIndexShapeValueVec); auto indexShape = indexType.getShape(); SmallVector toConcatIndexShapeVec(indexShape.begin(), @@ -795,17 +800,17 @@ class ConvertAtenScatterOp : public ConvertAtenOp { SmallVector toConcat; for (int64_t i = 0; i < inputType.getRank(); ++i) { if (i == dim) { - toConcat.push_back(rewriter.create( - loc, toConcatIndexType, index, toConcatIndexShape)); + toConcat.push_back(stablehlo::DynamicReshapeOp::create( + rewriter, loc, toConcatIndexType, index, toConcatIndexShape)); } else { - toConcat.push_back(rewriter.create( - loc, toConcatIndexType, toConcatIndexShape, + toConcat.push_back(stablehlo::DynamicIotaOp::create( + rewriter, loc, toConcatIndexType, toConcatIndexShape, rewriter.getI64IntegerAttr(i))); } } - auto scatterIndicies = rewriter.create( - loc, toConcat, static_cast(inputType.getRank())); + auto scatterIndicies = stablehlo::ConcatenateOp::create( + rewriter, loc, toConcat, static_cast(inputType.getRank())); SmallVector sliceSizes(inputType.getRank(), 1); // generate ScatterDimensionNumbers for stablehlo::Scatter op. @@ -825,9 +830,9 @@ class ConvertAtenScatterOp : public ConvertAtenOp { /*scatterDimsToOperandDim=*/scatterDimOperandDimMap, /*indexVectorDim=*/indexVecDim); - auto stablehloScatterOp = rewriter.create( - loc, inputType, input, scatterIndicies, src, scatterDimensionNumbers, - false, false); + auto stablehloScatterOp = stablehlo::ScatterOp::create( + rewriter, loc, inputType, input, scatterIndicies, src, + scatterDimensionNumbers, false, false); // config update computation function: just return the element from src. 
Block &block = stablehloScatterOp.getUpdateComputation().emplaceBlock(); @@ -844,11 +849,11 @@ class ConvertAtenScatterOp : public ConvertAtenOp { OpBuilder::InsertionGuard guard(rewriter); rewriter.setInsertionPointToStart(&block); if (reduceType == 0) { - rewriter.create(loc, *rhsArg); + stablehlo::ReturnOp::create(rewriter, loc, *rhsArg); } else if (reduceType == 1) { - Value res = rewriter.create(loc, blockArgumentType, - *lhsArg, *rhsArg); - rewriter.create(loc, res); + Value res = stablehlo::AddOp::create(rewriter, loc, blockArgumentType, + *lhsArg, *rhsArg); + stablehlo::ReturnOp::create(rewriter, loc, res); } } @@ -1000,12 +1005,11 @@ LogicalResult ConvertAtenOp::matchAndRewrite( valuesType = RankedTensorType::get(expectedValuesShape, valuesType.getElementType()); - values = - hlo::promoteAndBroadcast(rewriter, values, valuesType, - rewriter - .create( - op->getLoc(), expectedValuesShapeTensors) - .getResult()); + values = hlo::promoteAndBroadcast( + rewriter, values, valuesType, + tensor::FromElementsOp::create(rewriter, op->getLoc(), + expectedValuesShapeTensors) + .getResult()); valueRank = valuesType.getRank(); valuesShape = valuesType.getShape(); @@ -1030,9 +1034,9 @@ LogicalResult ConvertAtenOp::matchAndRewrite( /*scatterDimsToOperandDim=*/scatterDimOperandDimMap, /*indexVectorDim=*/indexVecDim); - auto stablehloScatterOp = rewriter.create( - loc, outType, input, scatterIndices, values, scatterDimensionNumbers, - false, false); + auto stablehloScatterOp = stablehlo::ScatterOp::create( + rewriter, loc, outType, input, scatterIndices, values, + scatterDimensionNumbers, false, false); // configure update computation function. 
Block &block = stablehloScatterOp.getUpdateComputation().emplaceBlock(); @@ -1049,11 +1053,11 @@ LogicalResult ConvertAtenOp::matchAndRewrite( OpBuilder::InsertionGuard guard(rewriter); rewriter.setInsertionPointToStart(&block); if (!accumulate) { - rewriter.create(loc, *rhsArg); + stablehlo::ReturnOp::create(rewriter, loc, *rhsArg); } else { - Value out = rewriter.create(loc, blockArgumentType, - *lhsArg, *rhsArg); - rewriter.create(loc, out); + Value out = stablehlo::AddOp::create(rewriter, loc, blockArgumentType, + *lhsArg, *rhsArg); + stablehlo::ReturnOp::create(rewriter, loc, out); } } @@ -1078,8 +1082,8 @@ static Value getConstantLike(OpBuilder &b, Location loc, T constant, return complex::NumberAttr::get(complexTy, constant, 0); llvm_unreachable("unhandled element type"); }; - return b.create(loc, cast(getAttr()), - val); + return mlir::chlo::ConstantLikeOp::create(b, loc, cast(getAttr()), + val); } template @@ -1089,7 +1093,7 @@ static Value getConstTensor(ConversionPatternRewriter &rewriter, Operation *op, Location loc = op->getLoc(); RankedTensorType valueType = RankedTensorType::get(shape, ty); auto valueAttr = DenseElementsAttr::get(valueType, values); - return rewriter.create(loc, valueType, valueAttr); + return stablehlo::ConstantOp::create(rewriter, loc, valueType, valueAttr); } template @@ -1119,11 +1123,11 @@ static Value unnormalize(ConversionPatternRewriter &rewriter, Operation *op, // use chlo::BroadcastMulOp to multiply constMul with coords. DenseI64ArrayAttr bcastDimensions; - Value mulResult = rewriter.create(loc, coords, constMul, - bcastDimensions); + Value mulResult = chlo::BroadcastMulOp::create(rewriter, loc, coords, + constMul, bcastDimensions); // use chlo::BroadcastAddOp to add constOfs to mulResult. 
- Value result = rewriter.create(loc, mulResult, constOfs, - bcastDimensions); + Value result = chlo::BroadcastAddOp::create(rewriter, loc, mulResult, + constOfs, bcastDimensions); return result; } @@ -1172,20 +1176,22 @@ static Value inBoundsCond(ConversionPatternRewriter &rewriter, Operation *op, chlo::ComparisonDirectionAttr::get(rewriter.getContext(), chlo::ComparisonDirection::GE); DenseI64ArrayAttr bcastDimensions; - Value cond1 = rewriter.create( - loc, xs, zero, bcastDimensions, compareGEAttr, compareTypeAttr); - Value cond2 = rewriter.create( - loc, xs, iwFloatValue, bcastDimensions, compareLTAttr, compareTypeAttr); - Value cond3 = rewriter.create( - loc, ys, zero, bcastDimensions, compareGEAttr, compareTypeAttr); - Value cond4 = rewriter.create( - loc, ys, ihFloatValue, bcastDimensions, compareLTAttr, compareTypeAttr); - Value cond5 = - rewriter.create(loc, cond1, cond2, bcastDimensions); - Value cond6 = - rewriter.create(loc, cond3, cond4, bcastDimensions); - return rewriter.create(loc, cond5, cond6, - bcastDimensions); + Value cond1 = chlo::BroadcastCompareOp::create( + rewriter, loc, xs, zero, bcastDimensions, compareGEAttr, compareTypeAttr); + Value cond2 = chlo::BroadcastCompareOp::create( + rewriter, loc, xs, iwFloatValue, bcastDimensions, compareLTAttr, + compareTypeAttr); + Value cond3 = chlo::BroadcastCompareOp::create( + rewriter, loc, ys, zero, bcastDimensions, compareGEAttr, compareTypeAttr); + Value cond4 = chlo::BroadcastCompareOp::create( + rewriter, loc, ys, ihFloatValue, bcastDimensions, compareLTAttr, + compareTypeAttr); + Value cond5 = chlo::BroadcastAndOp::create(rewriter, loc, cond1, cond2, + bcastDimensions); + Value cond6 = chlo::BroadcastAndOp::create(rewriter, loc, cond3, cond4, + bcastDimensions); + return chlo::BroadcastAndOp::create(rewriter, loc, cond5, cond6, + bcastDimensions); } // def clip(xs: Tensor, ys: Tensor, ws: Tensor) -> TensorSequenceType: // cond = in_bounds_cond(xs, ys) @@ -1205,31 +1211,32 @@ SmallVector 
clip(ConversionPatternRewriter &rewriter, Operation *op, auto indexElemTy = rewriter.getI64Type(); auto indexTy = RankedTensorType::get(mlir::ArrayRef{1}, indexElemTy); - Value zeroIntValue = rewriter.create( - loc, indexTy, DenseIntElementsAttr::get(indexTy, ArrayRef{0})); + Value zeroIntValue = stablehlo::ConstantOp::create( + rewriter, loc, indexTy, + DenseIntElementsAttr::get(indexTy, ArrayRef{0})); APFloat zeroAPFloat = APFloat(cast(elemTy).getFloatSemantics(), 0); Value zeroFloatValue = getConstScalarTensor(rewriter, op, zeroAPFloat, elemTy); Value cond = inBoundsCond(rewriter, op, xs, ys, iH, iW, elemTy); - Value xsInt = rewriter.create(loc, xs, indexElemTy); - Value ysInt = rewriter.create(loc, ys, indexElemTy); + Value xsInt = stablehlo::ConvertOp::create(rewriter, loc, xs, indexElemTy); + Value ysInt = stablehlo::ConvertOp::create(rewriter, loc, ys, indexElemTy); - Value selectXs = rewriter.create( - loc, ArrayRef{cond, xsInt, zeroIntValue}); - Value selectYs = rewriter.create( - loc, ArrayRef{cond, ysInt, zeroIntValue}); - Value selectWs = rewriter.create( - loc, ArrayRef{cond, ws, zeroFloatValue}); + Value selectXs = chlo::BroadcastSelectOp::create( + rewriter, loc, ArrayRef{cond, xsInt, zeroIntValue}); + Value selectYs = chlo::BroadcastSelectOp::create( + rewriter, loc, ArrayRef{cond, ysInt, zeroIntValue}); + Value selectWs = chlo::BroadcastSelectOp::create( + rewriter, loc, ArrayRef{cond, ws, zeroFloatValue}); SmallVector sizes = {N, 1, oH, oW}; - Value reshapedXs = rewriter.create( - loc, RankedTensorType::get(sizes, indexElemTy), selectXs); - Value reshapedYs = rewriter.create( - loc, RankedTensorType::get(sizes, indexElemTy), selectYs); - Value reshapedWs = rewriter.create( - loc, RankedTensorType::get(sizes, elemTy), selectWs); + Value reshapedXs = stablehlo::ReshapeOp::create( + rewriter, loc, RankedTensorType::get(sizes, indexElemTy), selectXs); + Value reshapedYs = stablehlo::ReshapeOp::create( + rewriter, loc, RankedTensorType::get(sizes, 
indexElemTy), selectYs); + Value reshapedWs = stablehlo::ReshapeOp::create( + rewriter, loc, RankedTensorType::get(sizes, elemTy), selectWs); return SmallVector{reshapedXs, reshapedYs, reshapedWs}; } @@ -1284,13 +1291,13 @@ Value getSummand(ConversionPatternRewriter &rewriter, Operation *op, } } - Value gather = rewriter.create( - loc, input, gatherIndices, dimsAttr, - rewriter.getDenseI64ArrayAttr(sliceSizes)); + Value gather = + stablehlo::GatherOp::create(rewriter, loc, input, gatherIndices, dimsAttr, + rewriter.getDenseI64ArrayAttr(sliceSizes)); // use chlo::BroadcastMulOp to multiply idxW with gather. DenseI64ArrayAttr bcastDimensions; - return rewriter.create(loc, gather, idxW, - bcastDimensions); + return chlo::BroadcastMulOp::create(rewriter, loc, gather, idxW, + bcastDimensions); } } // namespace @@ -1348,41 +1355,45 @@ LogicalResult ConvertAtenOp::matchAndRewrite( Type indexElemTy = rewriter.getI64Type(); RankedTensorType indexTy = RankedTensorType::get(mlir::ArrayRef{1}, indexElemTy); - Value constN = rewriter.create( - loc, indexTy, DenseIntElementsAttr::get(indexTy, {N})); - Value constC = rewriter.create( - loc, indexTy, DenseIntElementsAttr::get(indexTy, {C})); + Value constN = stablehlo::ConstantOp::create( + rewriter, loc, indexTy, DenseIntElementsAttr::get(indexTy, {N})); + Value constC = stablehlo::ConstantOp::create( + rewriter, loc, indexTy, DenseIntElementsAttr::get(indexTy, {C})); APFloat one = APFloat(cast(elemTy).getFloatSemantics(), 1); APFloat zero = APFloat(cast(elemTy).getFloatSemantics(), 0); Value constOneFloat = getConstScalarTensor(rewriter, op, one, elemTy); - auto NidxFlatten = rewriter.create( - loc, RankedTensorType::get(mlir::ArrayRef{N}, indexElemTy), - constN, 0); - auto CidxFlatten = rewriter.create( - loc, RankedTensorType::get(mlir::ArrayRef{C}, indexElemTy), - constC, 0); + auto NidxFlatten = stablehlo::DynamicIotaOp::create( + rewriter, loc, + RankedTensorType::get(mlir::ArrayRef{N}, indexElemTy), constN, + 0); + auto 
CidxFlatten = stablehlo::DynamicIotaOp::create( + rewriter, loc, + RankedTensorType::get(mlir::ArrayRef{C}, indexElemTy), constC, + 0); // Reshape NidxFlatten to 4D tensor (N, 1, 1, 1) auto NidxSizes = mlir::SmallVector{N, 1, 1, 1}; - auto Nidx = rewriter.create( - loc, RankedTensorType::get(NidxSizes, indexElemTy), NidxFlatten); + auto Nidx = stablehlo::ReshapeOp::create( + rewriter, loc, RankedTensorType::get(NidxSizes, indexElemTy), + NidxFlatten); // Reshape CidxFlatten to 4D tensor (1, C, 1, 1) auto CidxSizes = mlir::SmallVector{1, C, 1, 1}; - auto Cidx = rewriter.create( - loc, RankedTensorType::get(CidxSizes, indexElemTy), CidxFlatten); + auto Cidx = stablehlo::ReshapeOp::create( + rewriter, loc, RankedTensorType::get(CidxSizes, indexElemTy), + CidxFlatten); llvm::SmallVector stride(4, 1); - auto gridX = rewriter.create( - loc, + auto gridX = stablehlo::SliceOp::create( + rewriter, loc, RankedTensorType::get(mlir::SmallVector{N, oH, oW, 1}, gridTy.getElementType()), grid, mlir::SmallVector{0, 0, 0, 0}, mlir::SmallVector{N, oH, oW, 1}, stride); - auto gridY = rewriter.create( - loc, + auto gridY = stablehlo::SliceOp::create( + rewriter, loc, RankedTensorType::get(mlir::SmallVector{N, oH, oW, 1}, gridTy.getElementType()), grid, mlir::SmallVector{0, 0, 0, 1}, @@ -1390,26 +1401,28 @@ LogicalResult ConvertAtenOp::matchAndRewrite( // squeeze last dimension auto gridXshape = mlir::SmallVector{N, oH, oW}; - auto gridXReshape = rewriter.create( - loc, RankedTensorType::get(gridXshape, gridTy.getElementType()), gridX); - auto gridYReshape = rewriter.create( - loc, RankedTensorType::get(gridXshape, gridTy.getElementType()), gridY); + auto gridXReshape = stablehlo::ReshapeOp::create( + rewriter, loc, RankedTensorType::get(gridXshape, gridTy.getElementType()), + gridX); + auto gridYReshape = stablehlo::ReshapeOp::create( + rewriter, loc, RankedTensorType::get(gridXshape, gridTy.getElementType()), + gridY); if (interpolationMode == 0) { Value ix = 
computeSourceIndex(rewriter, op, gridXReshape, iW, elemTy, paddingMode, alignCorners); Value iy = computeSourceIndex(rewriter, op, gridYReshape, iH, elemTy, paddingMode, alignCorners); - Value ix_nw = rewriter.create(loc, ix); - Value iy_nw = rewriter.create(loc, iy); + Value ix_nw = stablehlo::FloorOp::create(rewriter, loc, ix); + Value iy_nw = stablehlo::FloorOp::create(rewriter, loc, iy); DenseI64ArrayAttr bcastDimensions; - Value ix_ne = rewriter.create( - loc, ix_nw, constOneFloat, bcastDimensions); + Value ix_ne = chlo::BroadcastAddOp::create(rewriter, loc, ix_nw, + constOneFloat, bcastDimensions); Value iy_ne = iy_nw; Value ix_sw = ix_nw; - Value iy_sw = rewriter.create( - loc, iy_nw, constOneFloat, bcastDimensions); + Value iy_sw = chlo::BroadcastAddOp::create(rewriter, loc, iy_nw, + constOneFloat, bcastDimensions); Value ix_se = ix_ne; Value iy_se = iy_sw; @@ -1417,25 +1430,25 @@ LogicalResult ConvertAtenOp::matchAndRewrite( // w_ne = (ix - ix_sw) * (iy_sw - iy) // w_sw = (ix_ne - ix) * (iy - iy_ne) // w_se = (ix - ix_nw) * (iy - iy_nw) - Value w_nw = rewriter.create( - loc, - rewriter.create(loc, ix_se, ix, bcastDimensions), - rewriter.create(loc, iy_se, iy, bcastDimensions), + Value w_nw = chlo::BroadcastMulOp::create( + rewriter, loc, + chlo::BroadcastSubOp::create(rewriter, loc, ix_se, ix, bcastDimensions), + chlo::BroadcastSubOp::create(rewriter, loc, iy_se, iy, bcastDimensions), bcastDimensions); - Value w_ne = rewriter.create( - loc, - rewriter.create(loc, ix, ix_sw, bcastDimensions), - rewriter.create(loc, iy_sw, iy, bcastDimensions), + Value w_ne = chlo::BroadcastMulOp::create( + rewriter, loc, + chlo::BroadcastSubOp::create(rewriter, loc, ix, ix_sw, bcastDimensions), + chlo::BroadcastSubOp::create(rewriter, loc, iy_sw, iy, bcastDimensions), bcastDimensions); - Value w_sw = rewriter.create( - loc, - rewriter.create(loc, ix_ne, ix, bcastDimensions), - rewriter.create(loc, iy, iy_ne, bcastDimensions), + Value w_sw = chlo::BroadcastMulOp::create( + 
rewriter, loc, + chlo::BroadcastSubOp::create(rewriter, loc, ix_ne, ix, bcastDimensions), + chlo::BroadcastSubOp::create(rewriter, loc, iy, iy_ne, bcastDimensions), bcastDimensions); - Value w_se = rewriter.create( - loc, - rewriter.create(loc, ix, ix_nw, bcastDimensions), - rewriter.create(loc, iy, iy_nw, bcastDimensions), + Value w_se = chlo::BroadcastMulOp::create( + rewriter, loc, + chlo::BroadcastSubOp::create(rewriter, loc, ix, ix_nw, bcastDimensions), + chlo::BroadcastSubOp::create(rewriter, loc, iy, iy_nw, bcastDimensions), bcastDimensions); Value summand_nw = @@ -1452,17 +1465,17 @@ LogicalResult ConvertAtenOp::matchAndRewrite( Nidx, Cidx, outTy, elemTy, options.dimSizeIndexBits); // summand_nw + summand_ne + summand_sw + summand_se - Value sum = rewriter.create(loc, summand_nw, summand_ne); - sum = rewriter.create(loc, sum, summand_sw); - sum = rewriter.create(loc, sum, summand_se); + Value sum = stablehlo::AddOp::create(rewriter, loc, summand_nw, summand_ne); + sum = stablehlo::AddOp::create(rewriter, loc, sum, summand_sw); + sum = stablehlo::AddOp::create(rewriter, loc, sum, summand_se); rewriter.replaceOp(op, sum); } else if (interpolationMode == 1) { Value ix = computeSourceIndex(rewriter, op, gridXReshape, iW, elemTy, paddingMode, alignCorners); Value iy = computeSourceIndex(rewriter, op, gridYReshape, iH, elemTy, paddingMode, alignCorners); - Value ix_round = rewriter.create(loc, ix); - Value iy_round = rewriter.create(loc, iy); + Value ix_round = stablehlo::RoundOp::create(rewriter, loc, ix); + Value iy_round = stablehlo::RoundOp::create(rewriter, loc, iy); Value oneTensor = getConstantLike(rewriter, loc, 1.0, ix_round); Value summand = getSummand(rewriter, op, input, ix_round, iy_round, oneTensor, N, oH, oW, iH, iW, Nidx, Cidx, outTy, diff --git a/lib/Conversion/TorchToStablehlo/Linear.cpp b/lib/Conversion/TorchToStablehlo/Linear.cpp index b42ed7cc7722..56094c8d0f52 100644 --- a/lib/Conversion/TorchToStablehlo/Linear.cpp +++ 
b/lib/Conversion/TorchToStablehlo/Linear.cpp @@ -33,15 +33,16 @@ Value getBroadcastTensor(PatternRewriter &rewriter, Operation *op, Value tensor, ArrayRef broadcastDims) { auto tensorTy = dyn_cast(tensor.getType()); auto loc = op->getLoc(); - Value stablehloShape = rewriter.create(loc, dimSizes); + Value stablehloShape = + tensor::FromElementsOp::create(rewriter, loc, dimSizes); RankedTensorType outTy = RankedTensorType::get(shape, tensorTy.getElementType()); auto broadcastAttr = rewriter.getDenseI64ArrayAttr(broadcastDims); - auto broadcast = rewriter.create( - loc, outTy, tensor, stablehloShape, broadcastAttr); + auto broadcast = stablehlo::DynamicBroadcastInDimOp::create( + rewriter, loc, outTy, tensor, stablehloShape, broadcastAttr); return broadcast; } @@ -59,8 +60,8 @@ Value getPermutedTensor(PatternRewriter &rewriter, Operation *op, Value input, } auto outTy = RankedTensorType::get(newShape, inputTy.getElementType()); - auto result = rewriter.create(op->getLoc(), outTy, - input, transDims); + auto result = stablehlo::TransposeOp::create(rewriter, op->getLoc(), outTy, + input, transDims); return result.getResult(); } @@ -85,12 +86,12 @@ RankedTensorType castContractingDim(PatternRewriter &rewriter, Operation *op, rhsContractingDimSize >= 0) { lhsShape[lhsContractingDim] = rhsContractingDimSize; auto newRankTy = RankedTensorType::get(lhsShape, lhsTy.getElementType()); - lhs = rewriter.create(op->getLoc(), newRankTy, lhs); + lhs = tensor::CastOp::create(rewriter, op->getLoc(), newRankTy, lhs); } else if (rhsContractingDimSize == ShapedType::kDynamic && lhsContractingDimSize >= 0) { rhsShape[rhsContractingDim] = lhsContractingDimSize; auto newRankTy = RankedTensorType::get(rhsShape, rhsTy.getElementType()); - rhs = rewriter.create(op->getLoc(), newRankTy, rhs); + rhs = tensor::CastOp::create(rewriter, op->getLoc(), newRankTy, rhs); } } SmallVector outShape; @@ -278,8 +279,8 @@ class ConvertAtenMatmulBaseOp : public ConvertAtenOp { if (lhsRank <= 2 && rhsRank <= 
2) { auto tensorType = ConvertAtenOp::getTypeConverter()->convertType(op.getType()); - output = rewriter.create(op->getLoc(), tensorType, lhs, - rhs, nullptr); + output = stablehlo::DotOp::create(rewriter, op->getLoc(), tensorType, lhs, + rhs, nullptr); return success(); } @@ -323,11 +324,10 @@ class ConvertAtenMatmulBaseOp : public ConvertAtenOp { auto outTy = castContractingDim(rewriter, op, lhs, rhs, lhsResultDim, rhsResultDim, lhsContractingDim, rhsContractingDim); - output = rewriter - .create(op->getLoc(), outTy, lhs, rhs, - dotDimensionNumbers, nullptr, - nullptr) - .getResult(); + output = + stablehlo::DotGeneralOp::create(rewriter, op->getLoc(), outTy, lhs, rhs, + dotDimensionNumbers, nullptr, nullptr) + .getResult(); return success(); } @@ -494,16 +494,17 @@ class ConvertAtenLinearOp : public ConvertAtenMatmulBaseOp { /*rhsBatchingDimensions=*/batchDims, /*lhsContractingDimensions=*/{lhsContractingDim}, /*rhsContractingDimensions=*/{rhsContractingDim}); - Value matmulOutput = rewriter.create( - op->getLoc(), outTy, lhs, rhs, dotDimensionNumbers, nullptr, nullptr); + Value matmulOutput = + stablehlo::DotGeneralOp::create(rewriter, op->getLoc(), outTy, lhs, rhs, + dotDimensionNumbers, nullptr, nullptr); Value matmulPlusBias = matmulOutput; if (!isa(biasTy)) { // Bias addition broadcasts to the matmul output shape. - matmulPlusBias = rewriter - .create( - op->getLoc(), outTy, matmulOutput, bias, nullptr) - .getResult(); + matmulPlusBias = + chlo::BroadcastAddOp::create(rewriter, op->getLoc(), outTy, + matmulOutput, bias, nullptr) + .getResult(); } auto resultTy = @@ -530,12 +531,13 @@ class ConvertAtenConvolutionOp : public ConvertAtenOp { std::copy(weightShape.begin(), weightShape.end(), weightShapeInt.begin()); // 1. 
[H, W, ..., OC, IC] => [H, W, ..., OC, G, IC//G] - Value GValue = rewriter.create( - op->getLoc(), rewriter.getIntegerAttr(rewriter.getIndexType(), groups)); - Value ICDivGValue = rewriter.create( - op->getLoc(), weightShapeVec[rank - 1], GValue); - Value OCMulGValue = rewriter.create( - op->getLoc(), weightShapeVec[rank - 2], GValue); + Value GValue = mlir::arith::ConstantOp::create( + rewriter, op->getLoc(), + rewriter.getIntegerAttr(rewriter.getIndexType(), groups)); + Value ICDivGValue = mlir::arith::DivSIOp::create( + rewriter, op->getLoc(), weightShapeVec[rank - 1], GValue); + Value OCMulGValue = mlir::arith::MulIOp::create( + rewriter, op->getLoc(), weightShapeVec[rank - 2], GValue); weightShapeVec[rank - 1] = ICDivGValue; weightShapeVec.insert(weightShapeVec.end() - 1, GValue); @@ -545,19 +547,20 @@ class ConvertAtenConvolutionOp : public ConvertAtenOp { weightShapeInt[rank - 1] /= groups; weightShapeInt.insert(weightShapeInt.end() - 1, groups); } - Value weightShapeTensor = rewriter.create( - op->getLoc(), weightShapeVec); - weight = rewriter.create( - op->getLoc(), RankedTensorType::get(weightShapeInt, weightElemTy), - weight, weightShapeTensor); + Value weightShapeTensor = mlir::tensor::FromElementsOp::create( + rewriter, op->getLoc(), weightShapeVec); + weight = stablehlo::DynamicReshapeOp::create( + rewriter, op->getLoc(), + RankedTensorType::get(weightShapeInt, weightElemTy), weight, + weightShapeTensor); // 2. [H, W, ..., OC, G, IC//G] => [H, W, ..., G, OC, IC//G] std::vector transposeDims(rank + 1); for (int64_t i = 0; i <= rank; i++) transposeDims[i] = i; std::swap(transposeDims[rank - 1], transposeDims[rank - 2]); - weight = rewriter.create(op->getLoc(), weight, - transposeDims); + weight = stablehlo::TransposeOp::create(rewriter, op->getLoc(), weight, + transposeDims); // 3. 
[H, W, ..., G, OC, IC//G] => [H, W, ..., G*OC, IC//G] weightShapeInt.erase(weightShapeInt.end() - 2); @@ -566,11 +569,12 @@ class ConvertAtenConvolutionOp : public ConvertAtenOp { } weightShapeVec.erase(weightShapeVec.end() - 2); weightShapeVec[weightShapeVec.size() - 2] = OCMulGValue; - weightShapeTensor = rewriter.create( - op->getLoc(), weightShapeVec); - weight = rewriter.create( - op->getLoc(), RankedTensorType::get(weightShapeInt, weightElemTy), - weight, weightShapeTensor); + weightShapeTensor = mlir::tensor::FromElementsOp::create( + rewriter, op->getLoc(), weightShapeVec); + weight = stablehlo::DynamicReshapeOp::create( + rewriter, op->getLoc(), + RankedTensorType::get(weightShapeInt, weightElemTy), weight, + weightShapeTensor); return weight; } @@ -609,10 +613,10 @@ class ConvertAtenConvolutionOp : public ConvertAtenOp { auto reverseDim = llvm::to_vector<4>(llvm::seq(0, kernelDims)); auto transposeTy = RankedTensorType::get(transposeShape, weightTy.getElementType()); - auto transposeOp = rewriter.create( - op->getLoc(), transposeTy, weight, perm); - auto reverseOp = rewriter.create( - op->getLoc(), transposeOp, reverseDim); + auto transposeOp = stablehlo::TransposeOp::create( + rewriter, op->getLoc(), transposeTy, weight, perm); + auto reverseOp = stablehlo::ReverseOp::create(rewriter, op->getLoc(), + transposeOp, reverseDim); // Prepare for transposed convolution SmallVector stablehloStrideVec(nSpatialDims, 1); @@ -664,8 +668,8 @@ class ConvertAtenConvolutionOp : public ConvertAtenOp { } // Create transposed convolution - auto transposedConvOp = rewriter.create( - op->getLoc(), convOutTy, input, weightInput, stablehloStride, + auto transposedConvOp = stablehlo::ConvolutionOp::create( + rewriter, op->getLoc(), convOutTy, input, weightInput, stablehloStride, stablehloPadding, stablehloLhsDilation, stablehloRhsDilation, windowReversal, dimensionNumbers, static_cast(groups), 1, precisionConfig); @@ -712,8 +716,8 @@ class ConvertAtenConvolutionOp : public 
ConvertAtenOp { DenseBoolArrayAttr windowReversal; ArrayAttr precisionConfig; - auto stablehloConvOp = rewriter.create( - op->getLoc(), outType, input, weight, stablehloWindowStride, + auto stablehloConvOp = stablehlo::ConvolutionOp::create( + rewriter, op->getLoc(), outType, input, weight, stablehloWindowStride, stablehloPadding, stablehloLhsDilation, stablehloRhsDilation, windowReversal, dimensionNumbers, static_cast(groups), 1, precisionConfig); diff --git a/lib/Conversion/TorchToStablehlo/Pooling.cpp b/lib/Conversion/TorchToStablehlo/Pooling.cpp index 915a58413e46..5c0ecb19c5a4 100644 --- a/lib/Conversion/TorchToStablehlo/Pooling.cpp +++ b/lib/Conversion/TorchToStablehlo/Pooling.cpp @@ -39,14 +39,14 @@ static Value createInitialValueForAtenPoolingOp(Operation *op, Type elementTy, constType, {APFloat::getZero( cast(elementTy).getFloatSemantics(), /*negative=*/false)}); - return rewriter.create(op->getLoc(), constType, - constAttr); + return stablehlo::ConstantOp::create(rewriter, op->getLoc(), constType, + constAttr); } else if (isa(elementTy) && elementTy.getIntOrFloatBitWidth() != 8) { auto constAttr = DenseElementsAttr::get( constType, {APInt::getZero(elementTy.getIntOrFloatBitWidth())}); - return rewriter.create(op->getLoc(), constType, - constAttr); + return stablehlo::ConstantOp::create(rewriter, op->getLoc(), constType, + constAttr); } } @@ -58,15 +58,15 @@ static Value createInitialValueForAtenPoolingOp(Operation *op, Type elementTy, constType, {APFloat::getInf(cast(elementTy).getFloatSemantics(), /*negative=*/true)}); - return rewriter.create(op->getLoc(), constType, - constAttr); + return stablehlo::ConstantOp::create(rewriter, op->getLoc(), constType, + constAttr); } else if (isa(elementTy) && elementTy.getIntOrFloatBitWidth() != 8) { auto constAttr = DenseElementsAttr::get( constType, {APInt::getSignedMinValue(elementTy.getIntOrFloatBitWidth())}); - return rewriter.create(op->getLoc(), constType, - constAttr); + return 
stablehlo::ConstantOp::create(rewriter, op->getLoc(), constType, + constAttr); } } op->emitError("unimplemented lowering in AtenPoolingOp"); @@ -173,22 +173,21 @@ LogicalResult ConvertAtenOp::matchAndRewrite( op, "failed to get dimension sizes of the input"); } auto inputShapeVec = *inputShapeInfo; - auto inputShapeTensor = rewriter.create( - op->getLoc(), inputShapeVec); + auto inputShapeTensor = mlir::tensor::FromElementsOp::create( + rewriter, op->getLoc(), inputShapeVec); // no need to reshape here for max_pool_1d. Need to make sure the iota // dimension. dim=inputRank-2 or dim=inputRank-1? auto indexTensor = - rewriter - .create( - op->getLoc(), - RankedTensorType::get(inputShape, rewriter.getI64Type()), - inputShapeTensor, static_cast(inputRank - 1)) + stablehlo::DynamicIotaOp::create( + rewriter, op->getLoc(), + RankedTensorType::get(inputShape, rewriter.getI64Type()), + inputShapeTensor, static_cast(inputRank - 1)) .getResult(); Value initIdx = hlo::getConstTensor(rewriter, op, {0}, {}).value(); - auto reduceWindowOp = rewriter.create( - op->getLoc(), mlir::TypeRange{outValTy, outIdxTy}, + auto reduceWindowOp = stablehlo::ReduceWindowOp::create( + rewriter, op->getLoc(), mlir::TypeRange{outValTy, outIdxTy}, mlir::ValueRange{input, indexTensor}, mlir::ValueRange{initVal, initIdx}, windowDimensions, windowStrides, baseDilations, windowDilations, pad); @@ -226,25 +225,25 @@ LogicalResult ConvertAtenOp::matchAndRewrite( OpBuilder::InsertionGuard guard(rewriter); rewriter.setInsertionPointToStart(&block); - Value compareGeResult = rewriter.create( - op->getLoc(), compareResultType, *firstValArg, *secondValArg, + Value compareGeResult = stablehlo::CompareOp::create( + rewriter, op->getLoc(), compareResultType, *firstValArg, *secondValArg, compareGeDirectionAttr, compareTypeAttr); - Value retValResult = rewriter.create( - op->getLoc(), compareGeResult, *firstValArg, *secondValArg); + Value retValResult = stablehlo::SelectOp::create( + rewriter, op->getLoc(), 
compareGeResult, *firstValArg, *secondValArg); // Get smaller index if compared values are equal. - Value compareEqResult = rewriter.create( - op->getLoc(), compareResultType, *firstValArg, *secondValArg, + Value compareEqResult = stablehlo::CompareOp::create( + rewriter, op->getLoc(), compareResultType, *firstValArg, *secondValArg, compareEqDirectionAttr, compareTypeAttr); - Value minIdx = rewriter.create(op->getLoc(), *firstIdxArg, - *secondIdxArg); - Value idxWithGeVal = rewriter.create( - op->getLoc(), compareGeResult, *firstIdxArg, *secondIdxArg); - Value retIdxResult = rewriter.create( - op->getLoc(), compareEqResult, minIdx, idxWithGeVal); - - rewriter.create( - op->getLoc(), mlir::ValueRange{retValResult, retIdxResult}); + Value minIdx = stablehlo::MinOp::create(rewriter, op->getLoc(), + *firstIdxArg, *secondIdxArg); + Value idxWithGeVal = stablehlo::SelectOp::create( + rewriter, op->getLoc(), compareGeResult, *firstIdxArg, *secondIdxArg); + Value retIdxResult = stablehlo::SelectOp::create( + rewriter, op->getLoc(), compareEqResult, minIdx, idxWithGeVal); + + stablehlo::ReturnOp::create(rewriter, op->getLoc(), + mlir::ValueRange{retValResult, retIdxResult}); } rewriter.replaceOp(op, reduceWindowOp.getResults()); @@ -330,17 +329,17 @@ LogicalResult ConvertAtenOp::matchAndRewrite( op, "failed to get dimension sizes of the input"); } auto inputShapeVec = *inputShapeInfo; - auto inputShapeTensor = rewriter.create( - op->getLoc(), inputShapeVec); + auto inputShapeTensor = mlir::tensor::FromElementsOp::create( + rewriter, op->getLoc(), inputShapeVec); SmallVector initIndexShapeVec; for (int64_t i = 0; i < inputRank - 2; i++) initIndexShapeVec.push_back(inputShapeVec[i]); - initIndexShapeVec.push_back(rewriter.create( - op->getLoc(), inputShapeVec[inputRank - 1], + initIndexShapeVec.push_back(mlir::arith::MulIOp::create( + rewriter, op->getLoc(), inputShapeVec[inputRank - 1], inputShapeVec[inputRank - 2])); - auto initIndexShapeTensor = rewriter.create( - 
op->getLoc(), initIndexShapeVec); + auto initIndexShapeTensor = mlir::tensor::FromElementsOp::create( + rewriter, op->getLoc(), initIndexShapeVec); SmallVector initIndexShapeForType(inputShape.begin(), inputShape.end() - 2); @@ -353,26 +352,23 @@ LogicalResult ConvertAtenOp::matchAndRewrite( } auto initIndexTensor = - rewriter - .create( - op->getLoc(), - RankedTensorType::get(initIndexShapeForType, - rewriter.getI64Type()), - initIndexShapeTensor, static_cast(inputRank - 2)) + stablehlo::DynamicIotaOp::create( + rewriter, op->getLoc(), + RankedTensorType::get(initIndexShapeForType, rewriter.getI64Type()), + initIndexShapeTensor, static_cast(inputRank - 2)) .getResult(); auto indexTensor = - rewriter - .create( - op->getLoc(), - RankedTensorType::get(inputShape, rewriter.getI64Type()), - initIndexTensor, inputShapeTensor) + stablehlo::DynamicReshapeOp::create( + rewriter, op->getLoc(), + RankedTensorType::get(inputShape, rewriter.getI64Type()), + initIndexTensor, inputShapeTensor) .getResult(); Value initIdx = hlo::getConstTensor(rewriter, op, {0}, {}).value(); - auto reduceWindowOp = rewriter.create( - op->getLoc(), mlir::TypeRange{outValTy, outIdxTy}, + auto reduceWindowOp = stablehlo::ReduceWindowOp::create( + rewriter, op->getLoc(), mlir::TypeRange{outValTy, outIdxTy}, mlir::ValueRange{input, indexTensor}, mlir::ValueRange{initVal, initIdx}, windowDimensions, windowStrides, baseDilations, windowDilations, pad); @@ -410,25 +406,25 @@ LogicalResult ConvertAtenOp::matchAndRewrite( OpBuilder::InsertionGuard guard(rewriter); rewriter.setInsertionPointToStart(&block); - Value compareGeResult = rewriter.create( - op->getLoc(), compareResultType, *firstValArg, *secondValArg, + Value compareGeResult = stablehlo::CompareOp::create( + rewriter, op->getLoc(), compareResultType, *firstValArg, *secondValArg, compareGeDirectionAttr, compareTypeAttr); - Value retValResult = rewriter.create( - op->getLoc(), compareGeResult, *firstValArg, *secondValArg); + Value retValResult = 
stablehlo::SelectOp::create( + rewriter, op->getLoc(), compareGeResult, *firstValArg, *secondValArg); // Get smaller index if compared values are equal. - Value compareEqResult = rewriter.create( - op->getLoc(), compareResultType, *firstValArg, *secondValArg, + Value compareEqResult = stablehlo::CompareOp::create( + rewriter, op->getLoc(), compareResultType, *firstValArg, *secondValArg, compareEqDirectionAttr, compareTypeAttr); - Value minIdx = rewriter.create(op->getLoc(), *firstIdxArg, - *secondIdxArg); - Value idxWithGeVal = rewriter.create( - op->getLoc(), compareGeResult, *firstIdxArg, *secondIdxArg); - Value retIdxResult = rewriter.create( - op->getLoc(), compareEqResult, minIdx, idxWithGeVal); - - rewriter.create( - op->getLoc(), mlir::ValueRange{retValResult, retIdxResult}); + Value minIdx = stablehlo::MinOp::create(rewriter, op->getLoc(), + *firstIdxArg, *secondIdxArg); + Value idxWithGeVal = stablehlo::SelectOp::create( + rewriter, op->getLoc(), compareGeResult, *firstIdxArg, *secondIdxArg); + Value retIdxResult = stablehlo::SelectOp::create( + rewriter, op->getLoc(), compareEqResult, minIdx, idxWithGeVal); + + stablehlo::ReturnOp::create(rewriter, op->getLoc(), + mlir::ValueRange{retValResult, retIdxResult}); } rewriter.replaceOp(op, reduceWindowOp.getResults()); @@ -558,9 +554,9 @@ class ConvertAtenMaxPoolOp : public ConvertAtenOp { rewriter.getI64Type()), stablehloPadding); - auto reduceWindowOp = rewriter.create( - op->getLoc(), outTy, input, initVal, windowDimensions, windowStrides, - baseDilations, windowDilations, pad); + auto reduceWindowOp = stablehlo::ReduceWindowOp::create( + rewriter, op->getLoc(), outTy, input, initVal, windowDimensions, + windowStrides, baseDilations, windowDilations, pad); Block &block = reduceWindowOp.getBody().emplaceBlock(); @@ -575,9 +571,9 @@ class ConvertAtenMaxPoolOp : public ConvertAtenOp { OpBuilder::InsertionGuard guard(rewriter); rewriter.setInsertionPointToStart(&block); - Value result = 
rewriter.create(op->getLoc(), *firstArg, - *secondArg); - rewriter.create(op->getLoc(), result); + Value result = stablehlo::MaxOp::create(rewriter, op->getLoc(), *firstArg, + *secondArg); + stablehlo::ReturnOp::create(rewriter, op->getLoc(), result); } rewriter.replaceOp(op, reduceWindowOp.getResults()); @@ -687,9 +683,9 @@ class ConvertAtenAvgPoolOp : public ConvertAtenOp { rewriter.getI64Type()), stablehloPadding); - auto reduceWindowSum = rewriter.create( - op->getLoc(), outTy, input, initVal, windowDimensions, windowStrides, - baseDilations, windowDilations, pad); + auto reduceWindowSum = stablehlo::ReduceWindowOp::create( + rewriter, op->getLoc(), outTy, input, initVal, windowDimensions, + windowStrides, baseDilations, windowDilations, pad); Block &sumBlock = reduceWindowSum.getBody().emplaceBlock(); @@ -705,8 +701,8 @@ class ConvertAtenAvgPoolOp : public ConvertAtenOp { rewriter.setInsertionPointToStart(&sumBlock); Value sumResult = - rewriter.create(op->getLoc(), firstArg, secondArg); - rewriter.create(op->getLoc(), sumResult); + stablehlo::AddOp::create(rewriter, op->getLoc(), firstArg, secondArg); + stablehlo::ReturnOp::create(rewriter, op->getLoc(), sumResult); } // Use kernel size as the divisor @@ -742,17 +738,17 @@ class ConvertAtenAvgPoolOp : public ConvertAtenOp { windowSizeConst = hlo::promoteType(rewriter, op.getLoc(), windowSizeConst, outTy.getElementType()); auto inputShapeVec = *hlo::getDimIndexOfTensor(rewriter, op, input); - auto inputShapeTensor = rewriter.create( - op->getLoc(), inputShapeVec); + auto inputShapeTensor = mlir::tensor::FromElementsOp::create( + rewriter, op->getLoc(), inputShapeVec); - windowSizeConst = rewriter.create( - op->getLoc(), + windowSizeConst = stablehlo::DynamicBroadcastInDimOp::create( + rewriter, op->getLoc(), RankedTensorType::get(inputTy.getShape(), outTy.getElementType()), windowSizeConst, inputShapeTensor, rewriter.getDenseI64ArrayAttr({})); Value zero = createInitialValueForAtenPoolingOp(op, inputElemTy, 
rewriter); - auto reduceWindowSize = rewriter.create( - op->getLoc(), RankedTensorType::get(outShape, inputElemTy), + auto reduceWindowSize = stablehlo::ReduceWindowOp::create( + rewriter, op->getLoc(), RankedTensorType::get(outShape, inputElemTy), windowSizeConst, zero, windowDimensions, windowStrides, baseDilations, windowDilations, pad); @@ -770,8 +766,8 @@ class ConvertAtenAvgPoolOp : public ConvertAtenOp { rewriter.setInsertionPointToStart(&sizeBlock); Value sumResult = - rewriter.create(op->getLoc(), firstArg, secondArg); - rewriter.create(op->getLoc(), sumResult); + stablehlo::AddOp::create(rewriter, op->getLoc(), firstArg, secondArg); + stablehlo::ReturnOp::create(rewriter, op->getLoc(), sumResult); } rewriter.replaceOpWithNewOp( @@ -830,9 +826,9 @@ LogicalResult ConvertAtenOp::matchAndRewrite( rewriter.getI64Type()), stablehloPadding); - auto reduceWindowSum = rewriter.create( - op->getLoc(), outTy, input, initVal, windowDimensions, windowStrides, - baseDilations, windowDilations, pad); + auto reduceWindowSum = stablehlo::ReduceWindowOp::create( + rewriter, op->getLoc(), outTy, input, initVal, windowDimensions, + windowStrides, baseDilations, windowDilations, pad); Block &sumBlock = reduceWindowSum.getBody().emplaceBlock(); @@ -848,8 +844,8 @@ LogicalResult ConvertAtenOp::matchAndRewrite( rewriter.setInsertionPointToStart(&sumBlock); Value sumResult = - rewriter.create(op->getLoc(), *firstArg, *secondArg); - rewriter.create(op->getLoc(), sumResult); + stablehlo::AddOp::create(rewriter, op->getLoc(), *firstArg, *secondArg); + stablehlo::ReturnOp::create(rewriter, op->getLoc(), sumResult); } rewriter.replaceOp(op, reduceWindowSum.getResults()); diff --git a/lib/Conversion/TorchToStablehlo/Reduction.cpp b/lib/Conversion/TorchToStablehlo/Reduction.cpp index c007ea7a69f5..f66d9e040951 100644 --- a/lib/Conversion/TorchToStablehlo/Reduction.cpp +++ b/lib/Conversion/TorchToStablehlo/Reduction.cpp @@ -104,8 +104,8 @@ static Value 
createInitialValueForReduceOp(Operation *op, Type elementTy, } if (constAttr != nullptr) { - return rewriter.create(op->getLoc(), constType, - constAttr); + return stablehlo::ConstantOp::create(rewriter, op->getLoc(), constType, + constAttr); } op->emitError("unimplemented lowering in " "createInitialValueForReduceOp"); @@ -124,8 +124,8 @@ static Value createReduceOpWithSingleRegionOp(Operation *op, Value input, if (!initValue) return nullptr; - stablehlo::ReduceOp reduce = rewriter.create( - op->getLoc(), outTy, input, initValue, + stablehlo::ReduceOp reduce = stablehlo::ReduceOp::create( + rewriter, op->getLoc(), outTy, input, initValue, rewriter.getDenseI64ArrayAttr(dims)); Block &block = reduce.getBody().emplaceBlock(); @@ -140,30 +140,30 @@ static Value createReduceOpWithSingleRegionOp(Operation *op, Value input, rewriter.setInsertionPointToStart(&block); Value result; if (isa(op)) { - result = rewriter.create( - op->getLoc(), blockArgumentTy, *firstArgument, *secondArgument); + result = stablehlo::MaxOp::create(rewriter, op->getLoc(), blockArgumentTy, + *firstArgument, *secondArgument); } else if (isa(op)) { - result = rewriter.create( - op->getLoc(), blockArgumentTy, *firstArgument, *secondArgument); + result = stablehlo::MinOp::create(rewriter, op->getLoc(), blockArgumentTy, + *firstArgument, *secondArgument); } else if (isa(op)) { - result = rewriter.create( - op->getLoc(), blockArgumentTy, *firstArgument, *secondArgument); + result = stablehlo::AddOp::create(rewriter, op->getLoc(), blockArgumentTy, + *firstArgument, *secondArgument); } else if (isa(op)) { - result = rewriter.create( - op->getLoc(), blockArgumentTy, *firstArgument, *secondArgument); + result = stablehlo::AndOp::create(rewriter, op->getLoc(), blockArgumentTy, + *firstArgument, *secondArgument); } else if (isa(op)) { - result = rewriter.create( - op->getLoc(), blockArgumentTy, *firstArgument, *secondArgument); + result = stablehlo::OrOp::create(rewriter, op->getLoc(), blockArgumentTy, + 
*firstArgument, *secondArgument); } else if (isa(op)) { - result = rewriter.create( - op->getLoc(), blockArgumentTy, *firstArgument, *secondArgument); + result = stablehlo::MulOp::create(rewriter, op->getLoc(), blockArgumentTy, + *firstArgument, *secondArgument); } else { op->emitError("unimplemented lowering in " "createReduceOpWithSingleRegionOp"); return nullptr; } - rewriter.create(op->getLoc(), result); + stablehlo::ReturnOp::create(rewriter, op->getLoc(), result); } return reduce.getResults()[0]; } @@ -198,16 +198,16 @@ createReduceOpReturnIndices(ConversionPatternRewriter &rewriter, Operation *op, auto outputIndexTy = RankedTensorType::get(outputShape, rewriter.getIntegerType(64)); - auto inputShapeTensor = rewriter.create( - op->getLoc(), inputShapeVec); - auto indexTensor = rewriter.create( - op->getLoc(), + auto inputShapeTensor = mlir::tensor::FromElementsOp::create( + rewriter, op->getLoc(), inputShapeVec); + auto indexTensor = stablehlo::DynamicIotaOp::create( + rewriter, op->getLoc(), RankedTensorType::get(inputShape, rewriter.getIntegerType(dimSizeIndexBits)), inputShapeTensor, static_cast(dim)); - auto stablehloReduceOp = rewriter.create( - op->getLoc(), TypeRange{outputTy, outputIndexTy}, + auto stablehloReduceOp = stablehlo::ReduceOp::create( + rewriter, op->getLoc(), TypeRange{outputTy, outputIndexTy}, ValueRange{input, indexTensor}, ValueRange{ initValue, @@ -258,33 +258,33 @@ createReduceOpReturnIndices(ConversionPatternRewriter &rewriter, Operation *op, Value compareResult; if (isa(op)) { - compareResult = rewriter.create( - op->getLoc(), compareResultType, *firstValArg, *secondValArg, - compareGeDirectionAttr, compareTypeAttr); + compareResult = stablehlo::CompareOp::create( + rewriter, op->getLoc(), compareResultType, *firstValArg, + *secondValArg, compareGeDirectionAttr, compareTypeAttr); } else if (isa(op)) { - compareResult = rewriter.create( - op->getLoc(), compareResultType, *firstValArg, *secondValArg, - compareLeDirectionAttr, 
compareTypeAttr); + compareResult = stablehlo::CompareOp::create( + rewriter, op->getLoc(), compareResultType, *firstValArg, + *secondValArg, compareLeDirectionAttr, compareTypeAttr); } else { op->emitError("unimplement lowering of createReduceOpReturnIndices"); return std::nullopt; } - Value retValResult = rewriter.create( - op->getLoc(), compareResult, *firstValArg, *secondValArg); + Value retValResult = stablehlo::SelectOp::create( + rewriter, op->getLoc(), compareResult, *firstValArg, *secondValArg); // get smaller index value if compared nums are equal. - Value compareEqResult = rewriter.create( - op->getLoc(), compareResultType, *firstValArg, *secondValArg, + Value compareEqResult = stablehlo::CompareOp::create( + rewriter, op->getLoc(), compareResultType, *firstValArg, *secondValArg, compareEqDirectionAttr, compareTypeAttr); - Value minIdx = rewriter.create(op->getLoc(), *firstIdxArg, - *secondIdxArg); - Value idxWithGeVal = rewriter.create( - op->getLoc(), compareResult, *firstIdxArg, *secondIdxArg); - Value retIdxResult = rewriter.create( - op->getLoc(), compareEqResult, minIdx, idxWithGeVal); - - rewriter.create( - op->getLoc(), ValueRange{retValResult, retIdxResult}); + Value minIdx = stablehlo::MinOp::create(rewriter, op->getLoc(), + *firstIdxArg, *secondIdxArg); + Value idxWithGeVal = stablehlo::SelectOp::create( + rewriter, op->getLoc(), compareResult, *firstIdxArg, *secondIdxArg); + Value retIdxResult = stablehlo::SelectOp::create( + rewriter, op->getLoc(), compareEqResult, minIdx, idxWithGeVal); + + stablehlo::ReturnOp::create(rewriter, op->getLoc(), + ValueRange{retValResult, retIdxResult}); } return stablehloReduceOp.getResults(); } @@ -295,15 +295,15 @@ static Value reshapeReduceResultWhenKeepDim(ConversionPatternRewriter &rewriter, Type outType, ArrayRef dims) { SmallVector outShapeVec(inputShapeVec); - Value one = rewriter.create( - loc, rewriter.getIntegerAttr(rewriter.getIndexType(), 1)); + Value one = arith::ConstantOp::create( + rewriter, 
loc, rewriter.getIntegerAttr(rewriter.getIndexType(), 1)); for (auto dim : dims) { outShapeVec[dim] = one; } auto outShapeTensor = - rewriter.create(loc, outShapeVec); - return rewriter.create( - loc, outType, reduceResult, outShapeTensor); + tensor::FromElementsOp::create(rewriter, loc, outShapeVec); + return stablehlo::DynamicReshapeOp::create(rewriter, loc, outType, + reduceResult, outShapeTensor); } namespace { @@ -345,8 +345,8 @@ class ConvertAtenReduceAllDimsOp : public ConvertAtenReductionOp { } if (inputElemTy != outTy.getElementType()) { // use output type as computation type - input = rewriter.create(op->getLoc(), input, - outTy.getElementType()); + input = stablehlo::ConvertOp::create(rewriter, op->getLoc(), input, + outTy.getElementType()); } SmallVector dims = @@ -386,8 +386,8 @@ class ConvertAtenReduceOneDimOp : public ConvertAtenReductionOp { } if (inputElemTy != outTy.getElementType()) { // use output type as computation type - input = rewriter.create(op->getLoc(), input, - outTy.getElementType()); + input = stablehlo::ConvertOp::create(rewriter, op->getLoc(), input, + outTy.getElementType()); } bool keepDim = false; @@ -451,8 +451,8 @@ class ConvertAtenReduceDimsOp : public ConvertAtenReductionOp { } if (inputElemTy != outTy.getElementType()) { // use output type as computation type - input = rewriter.create(op->getLoc(), input, - outTy.getElementType()); + input = stablehlo::ConvertOp::create(rewriter, op->getLoc(), input, + outTy.getElementType()); } bool keepDim = false; @@ -611,7 +611,7 @@ LogicalResult ConvertAtenReductionOp::matchAndRewrite( // Use output element type as computation type. 
auto dstElemTy = outTy.getElementType(); input = - rewriter.create(op->getLoc(), input, dstElemTy); + stablehlo::ConvertOp::create(rewriter, op->getLoc(), input, dstElemTy); inputTy = dyn_cast(input.getType()); } auto inputElemTy = inputTy.getElementType(); @@ -687,7 +687,7 @@ LogicalResult ConvertAtenReductionOp::matchAndRewrite( // Use output element type as computation type. auto dstElemTy = outTy.getElementType(); input = - rewriter.create(op->getLoc(), input, dstElemTy); + stablehlo::ConvertOp::create(rewriter, op->getLoc(), input, dstElemTy); inputTy = dyn_cast(input.getType()); } auto inputElemTy = inputTy.getElementType(); @@ -766,7 +766,7 @@ LogicalResult ConvertAtenReductionOp::matchAndRewrite( // Use output element type as computation type. auto dstElemTy = outTy.getElementType(); input = - rewriter.create(op->getLoc(), input, dstElemTy); + stablehlo::ConvertOp::create(rewriter, op->getLoc(), input, dstElemTy); inputTy = dyn_cast(input.getType()); } auto inputElemTy = inputTy.getElementType(); @@ -858,7 +858,8 @@ LogicalResult ConvertAtenReductionOp::matchAndRewrite( op, "non-const bool `keepdim` is not supported"); } - auto squareOp = rewriter.create(op->getLoc(), input, input); + auto squareOp = + stablehlo::MulOp::create(rewriter, op->getLoc(), input, input); Value reduceResult = createReduceOpWithSingleRegionOp( op, squareOp.getResult(), @@ -867,7 +868,8 @@ LogicalResult ConvertAtenReductionOp::matchAndRewrite( return op->emitError("createReduceOpWithSingleRegionOp return nullptr"); } - Value output = rewriter.create(op->getLoc(), reduceResult); + Value output = + stablehlo::SqrtOp::create(rewriter, op->getLoc(), reduceResult); if (keepDim) { auto outShapeInfo = hlo::getDimIndexOfTensor(rewriter, op, input); @@ -907,8 +909,8 @@ LogicalResult ConvertAtenReductionOp::matchAndRewrite( } if (inputType.getElementType() != outType.getElementType()) { - input = - rewriter.create(op->getLoc(), input, outElemType); + input = 
stablehlo::ConvertOp::create(rewriter, op->getLoc(), input, + outElemType); } Value ord = @@ -944,9 +946,9 @@ LogicalResult ConvertAtenReductionOp::matchAndRewrite( op, "non-const bool `keepdim` is not supported"); } - Value absValue = rewriter.create(op->getLoc(), input); - Value powValue = rewriter.create(op->getLoc(), absValue, - ord, nullptr); + Value absValue = stablehlo::AbsOp::create(rewriter, op->getLoc(), input); + Value powValue = chlo::BroadcastPowOp::create(rewriter, op->getLoc(), + absValue, ord, nullptr); Value reduceResult = createReduceOpWithSingleRegionOp( op, powValue, RankedTensorType::get(reduceResultShape, outElemType), dims, @@ -956,15 +958,15 @@ LogicalResult ConvertAtenReductionOp::matchAndRewrite( } auto scalarType = RankedTensorType::get({}, outElemType); - auto constantOne = rewriter.create( - op->getLoc(), scalarType, + auto constantOne = stablehlo::ConstantOp::create( + rewriter, op->getLoc(), scalarType, DenseElementsAttr::get( scalarType, APFloat(cast(outElemType).getFloatSemantics(), 1))); - auto reciprocalOrd = rewriter.create( - op->getLoc(), scalarType, constantOne, ord); - Value output = rewriter.create( - op->getLoc(), reduceResult, reciprocalOrd, nullptr); + auto reciprocalOrd = stablehlo::DivOp::create(rewriter, op->getLoc(), + scalarType, constantOne, ord); + Value output = chlo::BroadcastPowOp::create( + rewriter, op->getLoc(), reduceResult, reciprocalOrd, nullptr); if (keepDim) { auto outShapeInfo = hlo::getDimIndexOfTensor(rewriter, op, input); diff --git a/lib/Conversion/TorchToStablehlo/Rng.cpp b/lib/Conversion/TorchToStablehlo/Rng.cpp index 340c5198bf11..b71af126c69e 100644 --- a/lib/Conversion/TorchToStablehlo/Rng.cpp +++ b/lib/Conversion/TorchToStablehlo/Rng.cpp @@ -39,8 +39,8 @@ LogicalResult ConvertAtenOp::matchAndRewrite( if (llvm::any_of(elements, [](int64_t dim) { return dim == ShapedType::kDynamic; })) return rewriter.notifyMatchFailure(op, "Dynamic shape support TBD"); - auto shape_tensor = rewriter.create( - 
loc, rewriter.getI64TensorAttr(elements)); + auto shape_tensor = stablehlo::ConstantOp::create( + rewriter, loc, rewriter.getI64TensorAttr(elements)); auto outTy = getTypeConverter()->convertType(op.getType()); auto outElemTy = cast(outTy).getElementType(); Value from = @@ -77,13 +77,13 @@ LogicalResult ConvertAtenOp::matchAndRewrite( } auto scalarTy = RankedTensorType::get({}, outElemTy); - Value shapeTensor = rewriter.create( - loc, rewriter.getI64TensorAttr(shape)); - Value mean = rewriter.create( - loc, + Value shapeTensor = stablehlo::ConstantOp::create( + rewriter, loc, rewriter.getI64TensorAttr(shape)); + Value mean = stablehlo::ConstantOp::create( + rewriter, loc, DenseElementsAttr::get(scalarTy, rewriter.getFloatAttr(outElemTy, 0.0))); - Value var = rewriter.create( - loc, + Value var = stablehlo::ConstantOp::create( + rewriter, loc, DenseElementsAttr::get(scalarTy, rewriter.getFloatAttr(outElemTy, 1.0))); rewriter.replaceOpWithNewOp( @@ -108,8 +108,8 @@ LogicalResult ConvertAtenOp::matchAndRewrite( if (llvm::any_of(elements, [](int64_t dim) { return dim == ShapedType::kDynamic; })) return rewriter.notifyMatchFailure(op, "Dynamic shape support TBD"); - auto shapeTensor = rewriter.create( - loc, rewriter.getI64TensorAttr(elements)); + auto shapeTensor = stablehlo::ConstantOp::create( + rewriter, loc, rewriter.getI64TensorAttr(elements)); auto outTy = getTypeConverter()->convertType(op.getType()); auto outElemTy = cast(outTy).getElementType(); Value mean = diff --git a/lib/Conversion/TorchToStablehlo/StablehloLegalizeUtils.cpp b/lib/Conversion/TorchToStablehlo/StablehloLegalizeUtils.cpp index b22dc3e6ed30..9a3798a93cc9 100644 --- a/lib/Conversion/TorchToStablehlo/StablehloLegalizeUtils.cpp +++ b/lib/Conversion/TorchToStablehlo/StablehloLegalizeUtils.cpp @@ -39,8 +39,8 @@ Value getConstantLike(OpBuilder &rewriter, Location loc, T constant, return mlir::complex::NumberAttr::get(complexTy, constant, 0); llvm_unreachable("unhandled element type"); }; - return 
rewriter.create( - loc, cast(getAttr()), val); + return mlir::chlo::ConstantLikeOp::create(rewriter, loc, + cast(getAttr()), val); } // Template instantiation @@ -56,8 +56,8 @@ Value getStablehloConstTensorSingleF32(PatternRewriter &rewriter, Operation *op, auto const_type = RankedTensorType::get({}, rewriter.getF32Type()); auto const_attr = DenseElementsAttr::get(const_type, val); - auto const_op = rewriter.create( - op->getLoc(), const_type, const_attr); + auto const_op = stablehlo::ConstantOp::create(rewriter, op->getLoc(), + const_type, const_attr); return const_op.getResult(); } @@ -67,8 +67,8 @@ Value getStablehloConstTensorSingleF64(PatternRewriter &rewriter, Operation *op, auto const_type = RankedTensorType::get({}, rewriter.getF64Type()); auto const_attr = DenseElementsAttr::get(const_type, val); - auto const_op = rewriter.create( - op->getLoc(), const_type, const_attr); + auto const_op = stablehlo::ConstantOp::create(rewriter, op->getLoc(), + const_type, const_attr); return const_op.getResult(); } @@ -102,8 +102,8 @@ std::optional getConstTensor(PatternRewriter &rewriter, Operation *op, } auto const_attr = DenseElementsAttr::get(const_type, vec); - auto const_op = rewriter.create( - op->getLoc(), const_type, const_attr); + auto const_op = stablehlo::ConstantOp::create(rewriter, op->getLoc(), + const_type, const_attr); return const_op.getResult(); } @@ -153,27 +153,27 @@ Value getSplatConstTensor(ConversionPatternRewriter &rewriter, Operation *op, T val, Type dtype, llvm::ArrayRef dshape) { auto const_type = RankedTensorType::get(dshape, dtype); auto const_attr = SplatElementsAttr::get(const_type, val); - auto const_op = rewriter.create( - op->getLoc(), const_type, const_attr); + auto const_op = stablehlo::ConstantOp::create(rewriter, op->getLoc(), + const_type, const_attr); return const_op.getResult(); } Value scalarToStablehloTensor(ConversionPatternRewriter &rewriter, Operation *op, Value scalarValue, Type dtype) { - auto tensor = rewriter.create( - 
op->getLoc(), ArrayRef{scalarValue}); + auto tensor = tensor::FromElementsOp::create(rewriter, op->getLoc(), + ArrayRef{scalarValue}); auto dtype_tensor = - rewriter.create(op->getLoc(), tensor, dtype); - return rewriter.create( - op->getLoc(), RankedTensorType::get(mlir::ArrayRef{}, dtype), - dtype_tensor); + stablehlo::ConvertOp::create(rewriter, op->getLoc(), tensor, dtype); + return stablehlo::ReshapeOp::create( + rewriter, op->getLoc(), + RankedTensorType::get(mlir::ArrayRef{}, dtype), dtype_tensor); } Value promoteType(PatternRewriter &rewriter, Location loc, Value input, Type outElementType) { TensorType inType = cast(input.getType()); if (inType.getElementType() != outElementType) { - return rewriter.create(loc, input, outElementType); + return stablehlo::ConvertOp::create(rewriter, loc, input, outElementType); } return input; } @@ -194,8 +194,8 @@ Value promoteAndBroadcast(ConversionPatternRewriter &rewriter, Value input, if (in_type.getElementType() != outType.getElementType()) { TensorType promoted_type = in_type.cloneWith(in_type.getShape(), outType.getElementType()); - input = rewriter.create(op->getLoc(), promoted_type, - input); + input = stablehlo::ConvertOp::create(rewriter, op->getLoc(), promoted_type, + input); } ArrayRef inShape = in_type.getShape(); @@ -226,12 +226,13 @@ Value promoteAndBroadcast(ConversionPatternRewriter &rewriter, Value input, } auto bcast_attr = rewriter.getDenseI64ArrayAttr(bcastDims); if (bcastSizeTensor.has_value()) { - auto bcast_op = rewriter.create( - op->getLoc(), outType, input, bcastSizeTensor.value(), bcast_attr); + auto bcast_op = stablehlo::DynamicBroadcastInDimOp::create( + rewriter, op->getLoc(), outType, input, bcastSizeTensor.value(), + bcast_attr); return bcast_op.getResult(); } - auto bcast_op = rewriter.create( - op->getLoc(), outType, input, bcast_attr); + auto bcast_op = stablehlo::BroadcastInDimOp::create( + rewriter, op->getLoc(), outType, input, bcast_attr); return bcast_op.getResult(); } @@ -261,9 
+262,9 @@ FailureOr> getDimSizesOfTensor(PatternRewriter &rewriter, auto loc = op->getLoc(); for (auto d : dims) { - dimSizes.emplace_back(rewriter.create( - loc, rewriter.getIntegerType(dimSizeIndexBits), - rewriter.create(loc, value, d))); + dimSizes.emplace_back(arith::IndexCastOp::create( + rewriter, loc, rewriter.getIntegerType(dimSizeIndexBits), + tensor::DimOp::create(rewriter, loc, value, d))); } return dimSizes; } @@ -301,7 +302,7 @@ getDimIndexOfTensor(PatternRewriter &rewriter, Operation *op, Value value, auto loc = op->getLoc(); for (auto d : dims) { - dimSizes.emplace_back(rewriter.create(loc, value, d)); + dimSizes.emplace_back(tensor::DimOp::create(rewriter, loc, value, d)); } return dimSizes; } @@ -342,8 +343,8 @@ getBroadcastResultShape(PatternRewriter &rewriter, Operation *op, int dynamicDimCnt = 0; int staticDimCnt = 0; int64_t dimSize = -1; - Value dimSizeTensor = rewriter.create( - op->getLoc(), + Value dimSizeTensor = mlir::arith::ConstantOp::create( + rewriter, op->getLoc(), rewriter.getIntegerAttr(rewriter.getIntegerType(dimSizeIndexBits), 1)); for (size_t i = 0; i < tensorSizes.size(); ++i) { // loop tensors. 
@@ -400,7 +401,7 @@ getBroadcastResultShape(PatternRewriter &rewriter, Operation *op, std::reverse(bcastSizes.begin(), bcastSizes.end()); std::reverse(bcastSizeTensors.begin(), bcastSizeTensors.end()); return std::pair>( - rewriter.create(op->getLoc(), bcastSizeTensors) + tensor::FromElementsOp::create(rewriter, op->getLoc(), bcastSizeTensors) .getResult(), bcastSizes); } @@ -432,8 +433,8 @@ FailureOr unsqueezeTensor(PatternRewriter &rewriter, Operation *op, auto loc = op->getLoc(); auto rankTy = dyn_cast(tensor.getType()); auto oldShape = rankTy.getShape(); - auto one = rewriter.create( - loc, rewriter.getIntegerAttr(rewriter.getIndexType(), 1)); + auto one = arith::ConstantOp::create( + rewriter, loc, rewriter.getIntegerAttr(rewriter.getIndexType(), 1)); std::vector newDimSizes; std::vector newShape; @@ -452,8 +453,9 @@ FailureOr unsqueezeTensor(PatternRewriter &rewriter, Operation *op, } auto outTy = RankedTensorType::get(newShape, rankTy.getElementType()); - auto shape = rewriter.create(loc, newDimSizes); - return rewriter.create(loc, outTy, tensor, shape) + auto shape = tensor::FromElementsOp::create(rewriter, loc, newDimSizes); + return stablehlo::DynamicReshapeOp::create(rewriter, loc, outTy, tensor, + shape) .getResult(); } @@ -483,8 +485,8 @@ FailureOr collapseTensor(PatternRewriter &rewriter, Operation *op, newDimSizes.reserve(newRank); newShape.reserve(newRank); - Value collapseDimSize = rewriter.create( - loc, rewriter.getIntegerAttr(rewriter.getIndexType(), 1)); + Value collapseDimSize = arith::ConstantOp::create( + rewriter, loc, rewriter.getIntegerAttr(rewriter.getIndexType(), 1)); int64_t collapseShape = 1; for (int64_t k = collapseStartDim; k <= collapseEndDim; ++k) { @@ -499,7 +501,7 @@ FailureOr collapseTensor(PatternRewriter &rewriter, Operation *op, collapseShape *= oldShape[k]; } collapseDimSize = - rewriter.create(loc, collapseDimSize, dimSizes[k]); + arith::MulIOp::create(rewriter, loc, collapseDimSize, dimSizes[k]); } for (int64_t k = 0; k 
< collapseStartDim; ++k) { @@ -514,8 +516,9 @@ FailureOr collapseTensor(PatternRewriter &rewriter, Operation *op, } auto outTy = RankedTensorType::get(newShape, rankTy.getElementType()); - auto shape = rewriter.create(loc, newDimSizes); - return rewriter.create(loc, outTy, tensor, shape) + auto shape = tensor::FromElementsOp::create(rewriter, loc, newDimSizes); + return stablehlo::DynamicReshapeOp::create(rewriter, loc, outTy, tensor, + shape) .getResult(); } @@ -542,11 +545,12 @@ FailureOr splitTensor(PatternRewriter &rewriter, Operation *op, } int64_t newRank = rank + 1; - auto outerLengthValue = rewriter.create( - loc, rewriter.getIntegerAttr(rewriter.getIndexType(), outerLength)); + auto outerLengthValue = arith::ConstantOp::create( + rewriter, loc, + rewriter.getIntegerAttr(rewriter.getIndexType(), outerLength)); - auto innerLengthValue = rewriter.create( - loc, dimSizes[splitDim], outerLengthValue); + auto innerLengthValue = arith::DivSIOp::create( + rewriter, loc, dimSizes[splitDim], outerLengthValue); int64_t originShape = oldShape[splitDim]; int64_t outerShape = outerLength; @@ -575,8 +579,9 @@ FailureOr splitTensor(PatternRewriter &rewriter, Operation *op, } auto outTy = RankedTensorType::get(newShape, rankTy.getElementType()); - auto shape = rewriter.create(loc, newDimSizes); - return rewriter.create(loc, outTy, tensor, shape) + auto shape = tensor::FromElementsOp::create(rewriter, loc, newDimSizes); + return stablehlo::DynamicReshapeOp::create(rewriter, loc, outTy, tensor, + shape) .getResult(); } @@ -584,10 +589,10 @@ Value getConstantOfShape(PatternRewriter &rewriter, Location loc, const APFloat &constant, Value shape, TensorType outType) { auto constAttr = rewriter.getFloatAttr(outType.getElementType(), constant); - auto constTensor = rewriter.create(loc, constAttr); - return rewriter - .create( - loc, outType, constTensor, shape, rewriter.getDenseI64ArrayAttr({})) + auto constTensor = stablehlo::ConstantOp::create(rewriter, loc, constAttr); + return 
stablehlo::DynamicBroadcastInDimOp::create( + rewriter, loc, outType, constTensor, shape, + rewriter.getDenseI64ArrayAttr({})) .getResult(); } } // namespace hlo diff --git a/lib/Conversion/TorchToStablehlo/ViewLike.cpp b/lib/Conversion/TorchToStablehlo/ViewLike.cpp index 71b675b5ea2a..af48f84fc357 100644 --- a/lib/Conversion/TorchToStablehlo/ViewLike.cpp +++ b/lib/Conversion/TorchToStablehlo/ViewLike.cpp @@ -34,21 +34,21 @@ namespace { Value getNormalizedDimSizeInternal(PatternRewriter &rewriter, Operation *op, Value index, Value dimSize) { auto loc = op->getLoc(); - Value zero = rewriter.create( - loc, rewriter.getIntegerAttr(rewriter.getI64Type(), 0)); + Value zero = arith::ConstantOp::create( + rewriter, loc, rewriter.getIntegerAttr(rewriter.getI64Type(), 0)); // To normalize index into range [-dimSize, dimSize] // index = min(max(-dimSize, index), dimSize) - auto negDimSize = rewriter.create(loc, zero, dimSize); - index = rewriter.create(loc, negDimSize, index); - index = rewriter.create(loc, dimSize, index); + auto negDimSize = arith::SubIOp::create(rewriter, loc, zero, dimSize); + index = arith::MaxSIOp::create(rewriter, loc, negDimSize, index); + index = arith::MinSIOp::create(rewriter, loc, dimSize, index); - auto dimSizePlusIndex = rewriter.create(loc, dimSize, index); - auto indexPositive = rewriter.create( - loc, arith::CmpIPredicate::sge, index, zero); + auto dimSizePlusIndex = arith::AddIOp::create(rewriter, loc, dimSize, index); + auto indexPositive = arith::CmpIOp::create( + rewriter, loc, arith::CmpIPredicate::sge, index, zero); // get positive index: (index >=0) ? 
index: index + dimSize - return rewriter.create(loc, indexPositive, index, - dimSizePlusIndex); + return arith::SelectOp::create(rewriter, loc, indexPositive, index, + dimSizePlusIndex); } Value getDynamicSliceInternal(PatternRewriter &rewriter, Operation *op, @@ -59,10 +59,10 @@ Value getDynamicSliceInternal(PatternRewriter &rewriter, Operation *op, auto loc = op->getLoc(); // startIndex & endIndex has been normailized into range [0, dSize] Type intType = rewriter.getIntegerType(dimSizeIndexBits); - Value zero = rewriter.create( - loc, rewriter.getIntegerAttr(intType, 0)); - Value one = rewriter.create( - loc, rewriter.getIntegerAttr(intType, 1)); + Value zero = arith::ConstantOp::create(rewriter, loc, + rewriter.getIntegerAttr(intType, 0)); + Value one = arith::ConstantOp::create(rewriter, loc, + rewriter.getIntegerAttr(intType, 1)); SmallVector startIndices; SmallVector endIndices; @@ -74,10 +74,10 @@ Value getDynamicSliceInternal(PatternRewriter &rewriter, Operation *op, endIndices.reserve(rank); strides.reserve(rank); - auto endIndexIsZero = rewriter.create( - loc, arith::CmpIPredicate::eq, endIndex, zero); - endIndex = rewriter.create(loc, endIndexIsZero, - dimSizes[dimIndex], endIndex); + auto endIndexIsZero = arith::CmpIOp::create( + rewriter, loc, arith::CmpIPredicate::eq, endIndex, zero); + endIndex = arith::SelectOp::create(rewriter, loc, endIndexIsZero, + dimSizes[dimIndex], endIndex); for (size_t r = 0; r < rank; ++r) { if (r == dimIndex) { @@ -92,14 +92,14 @@ Value getDynamicSliceInternal(PatternRewriter &rewriter, Operation *op, } auto startTensor = - rewriter.create(loc, startIndices).getResult(); + tensor::FromElementsOp::create(rewriter, loc, startIndices).getResult(); auto endTensor = - rewriter.create(loc, endIndices).getResult(); + tensor::FromElementsOp::create(rewriter, loc, endIndices).getResult(); auto stridesTensor = - rewriter.create(loc, strides).getResult(); + tensor::FromElementsOp::create(rewriter, loc, strides).getResult(); - return 
rewriter.create( - loc, outTy, input, startTensor, endTensor, stridesTensor); + return stablehlo::RealDynamicSliceOp::create( + rewriter, loc, outTy, input, startTensor, endTensor, stridesTensor); } // Get a dynamic slice of the tensor from startIndex to endIndex with stride @@ -116,30 +116,32 @@ FailureOr getDynamicSlice(PatternRewriter &rewriter, Operation *op, auto rank = inputTy.getRank(); dim = (dim + rank) % rank; - Value dimSize = rewriter.create( - loc, rewriter.getI64Type(), - rewriter.create(loc, input, dim)); + Value dimSize = arith::IndexCastOp::create( + rewriter, loc, rewriter.getI64Type(), + tensor::DimOp::create(rewriter, loc, input, dim)); Value normStartIndex = startIndexOpt ? getNormalizedDimSizeInternal(rewriter, op, *startIndexOpt, dimSize) - : rewriter.create( - loc, rewriter.getIntegerAttr(rewriter.getI64Type(), 0)); + : arith::ConstantOp::create( + rewriter, loc, + rewriter.getIntegerAttr(rewriter.getI64Type(), 0)); Value normEndIndex = endIndexOpt ? getNormalizedDimSizeInternal(rewriter, op, *endIndexOpt, dimSize) : dimSize; - Value step = - stepOpt ? *stepOpt - : rewriter.create( - loc, rewriter.getIntegerAttr(rewriter.getI64Type(), 1)); + Value step = stepOpt ? 
*stepOpt + : arith::ConstantOp::create( + rewriter, loc, + rewriter.getIntegerAttr(rewriter.getI64Type(), 1)); if (dimSizeIndexBits == 32) { Type intType = rewriter.getIntegerType(dimSizeIndexBits); normStartIndex = - rewriter.create(loc, intType, normStartIndex); - normEndIndex = rewriter.create(loc, intType, normEndIndex); - step = rewriter.create(loc, intType, step); + arith::TruncIOp::create(rewriter, loc, intType, normStartIndex); + normEndIndex = + arith::TruncIOp::create(rewriter, loc, intType, normEndIndex); + step = arith::TruncIOp::create(rewriter, loc, intType, step); } FailureOr> dimSizesInfo = hlo::getDimSizesOfTensor(rewriter, op, input, dimSizeIndexBits); @@ -212,14 +214,14 @@ class ConvertAtenViewOp : public ConvertAtenOp { auto castType = baseResultTy.getWithSizesAndDtype(castShape, baseResultTy.getDtype()); - auto cast = rewriter.create( - loc, + auto cast = stablehlo::BitcastConvertOp::create( + rewriter, loc, OpConversionPattern::getTypeConverter()->convertType( castType), self); auto reshape = - rewriter.create(loc, resultType, cast); + stablehlo::ReshapeOp::create(rewriter, loc, resultType, cast); rewriter.replaceOp(op, reshape); @@ -256,14 +258,15 @@ class ConvertAtenViewOp : public ConvertAtenOp { } std::for_each(dimSizes.begin(), dimSizes.end(), [&](Value &dSize) { - dSize = rewriter.create(loc, dSize).getResult(); + dSize = ToI64Op::create(rewriter, loc, dSize).getResult(); return dSize; }); - Value numel = rewriter.create( - loc, rewriter.create(loc, adaptor.getSelf())); + Value numel = shape::NumElementsOp::create( + rewriter, loc, + shape::ShapeOfOp::create(rewriter, loc, adaptor.getSelf())); numel = - rewriter.create(loc, rewriter.getI64Type(), numel); + arith::IndexCastOp::create(rewriter, loc, rewriter.getI64Type(), numel); // note: assuming that -1 doesn't arise from dynamic value if (negOneIndex.size() == 1) { @@ -271,7 +274,7 @@ class ConvertAtenViewOp : public ConvertAtenOp { Value realDim = numel; for (size_t i = 0; i < 
dimSizes.size(); i++) { if (i != index) { - realDim = rewriter.create(loc, realDim, dimSizes[i]); + realDim = arith::DivUIOp::create(rewriter, loc, realDim, dimSizes[i]); } } // update -1 to realDim @@ -279,7 +282,7 @@ class ConvertAtenViewOp : public ConvertAtenOp { } Value stablehloShape = - rewriter.create(loc, dimSizes); + tensor::FromElementsOp::create(rewriter, loc, dimSizes); rewriter.replaceOpWithNewOp( op, OpConversionPattern::getTypeConverter()->convertType( @@ -393,7 +396,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( op, "failed to get dimension sizes of the input"); auto newDimSizes = *newDimSizesInfo; auto stablehloShape = - rewriter.create(op.getLoc(), newDimSizes); + tensor::FromElementsOp::create(rewriter, op.getLoc(), newDimSizes); rewriter.replaceOpWithNewOp( op, getTypeConverter()->convertType(op.getType()), self, stablehloShape); return success(); @@ -444,7 +447,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( op, "failed to get dimension sizes of the input"); auto newDimSizes = *newDimSizesInfo; auto stablehloShape = - rewriter.create(op.getLoc(), newDimSizes); + tensor::FromElementsOp::create(rewriter, op.getLoc(), newDimSizes); rewriter.replaceOpWithNewOp( op, getTypeConverter()->convertType(op.getType()), self, stablehloShape); return success(); diff --git a/lib/Conversion/TorchToTMTensor/TorchToTMTensor.cpp b/lib/Conversion/TorchToTMTensor/TorchToTMTensor.cpp index 8796fd249e9e..60a04bbd7e55 100644 --- a/lib/Conversion/TorchToTMTensor/TorchToTMTensor.cpp +++ b/lib/Conversion/TorchToTMTensor/TorchToTMTensor.cpp @@ -93,7 +93,7 @@ convertTorchScatterIndexAndSrcToTMScatterIndexAndSrc(PatternRewriter &rewriter, Value indexSize = getTensorSize(rewriter, loc, indices); indexSize = castIntToIndex(rewriter, loc, indexSize); SmallVector indexShape = getTensorSizes(rewriter, loc, indices); - Value cstOne = rewriter.create(loc, 1); + Value cstOne = arith::ConstantIndexOp::create(rewriter, loc, 1); // We flatten the `src` values from (i, j, k, 
...) -> (i * j * k * ...) SmallVector indSliceShape({indexSize, cstOne}); @@ -124,35 +124,35 @@ convertTorchScatterIndexAndSrcToTMScatterIndexAndSrc(PatternRewriter &rewriter, // on the current flattened index. The flattened iteration space is required // because TMTensorScatterOp expects a list of single element updates. auto flattenedUpdates = - rewriter - .create( - loc, outputsType, ValueRange(), outputs, mapping, iteratorTypes, - [&](OpBuilder &b, Location loc, ValueRange args) { - SmallVector indexValues(indexType.getRank()); - Value ind = b.create(loc, 0); - for (int i = indexType.getRank() - 1; i >= 0; i--) { - indexValues[i] = - b.create(loc, ind, indexShape[i]); - ind = b.create(loc, ind, indexShape[i]); - } - // Extract the scatter index and update value - Value extractIndexValue = - b.create(loc, indices, indexValues); - Value extractSrcValue = - b.create(loc, src, indexValues); - SmallVector yieldVals; - for (Value v : indexValues) { - Value scalar = castIndexToInt64(b, loc, v); - yieldVals.push_back(convertScalarToDtype( - rewriter, loc, scalar, indicesElemType)); - } - // Replace the original index with the index specified - // by the scatter. 
- yieldVals[dim] = convertScalarToDtype( - rewriter, loc, extractIndexValue, indicesElemType); - yieldVals.push_back(extractSrcValue); - b.create(loc, yieldVals); - }) + linalg::GenericOp::create( + rewriter, loc, outputsType, ValueRange(), outputs, mapping, + iteratorTypes, + [&](OpBuilder &b, Location loc, ValueRange args) { + SmallVector indexValues(indexType.getRank()); + Value ind = linalg::IndexOp::create(b, loc, 0); + for (int i = indexType.getRank() - 1; i >= 0; i--) { + indexValues[i] = + arith::RemSIOp::create(b, loc, ind, indexShape[i]); + ind = arith::DivSIOp::create(b, loc, ind, indexShape[i]); + } + // Extract the scatter index and update value + Value extractIndexValue = + tensor::ExtractOp::create(b, loc, indices, indexValues); + Value extractSrcValue = + tensor::ExtractOp::create(b, loc, src, indexValues); + SmallVector yieldVals; + for (Value v : indexValues) { + Value scalar = castIndexToInt64(b, loc, v); + yieldVals.push_back( + convertScalarToDtype(rewriter, loc, scalar, indicesElemType)); + } + // Replace the original index with the index specified + // by the scatter. + yieldVals[dim] = convertScalarToDtype( + rewriter, loc, extractIndexValue, indicesElemType); + yieldVals.push_back(extractSrcValue); + linalg::YieldOp::create(b, loc, yieldVals); + }) .getResultTensors(); auto toOpFoldResult = [](Value v) -> OpFoldResult { @@ -169,13 +169,13 @@ convertTorchScatterIndexAndSrcToTMScatterIndexAndSrc(PatternRewriter &rewriter, // new `src` tensor is the last tensor returned from the linalg::Generic // operation. 
SmallVector offsets = { - rewriter.create(loc, 0), - rewriter.create(loc, 0)}; + arith::ConstantIndexOp::create(rewriter, loc, 0), + arith::ConstantIndexOp::create(rewriter, loc, 0)}; SmallVector strides = { - rewriter.create(loc, 1), - rewriter.create(loc, 1)}; + arith::ConstantIndexOp::create(rewriter, loc, 1), + arith::ConstantIndexOp::create(rewriter, loc, 1)}; Value indicesRank = - rewriter.create(loc, indexType.getRank()); + arith::ConstantIndexOp::create(rewriter, loc, indexType.getRank()); Value flattenedIndices = createZeroInitTensor( rewriter, loc, SmallVector({indexSize, indicesRank}), indexType.getElementType()); @@ -216,8 +216,8 @@ static Value createTMTensorScatterOp( auto dimensionsMapAttr = b.getDenseI64ArrayAttr(dimensionsMap); auto originalTensorType = cast(original.getType()); Type originalElementType = originalTensorType.getElementType(); - auto scatterOp = b.create( - loc, originalTensorType, ValueRange{updates, indices}, + auto scatterOp = TMTensor::ScatterOp::create( + b, loc, originalTensorType, ValueRange{updates, indices}, ValueRange{original}, dimensionsMapAttr, uniqueIndices); Region &scatterOpRegion = scatterOp.getRegion(); @@ -238,8 +238,9 @@ static Value createTMTensorScanOp( function_ref bodyBuild) { auto inputType = cast(input.getType()); Type elementType = inputType.getElementType(); - auto scanOp = b.create( - loc, ValueRange{input}, ValueRange{output, accumulator}, dim, inclusive); + auto scanOp = + TMTensor::ScanOp::create(b, loc, ValueRange{input}, + ValueRange{output, accumulator}, dim, inclusive); Region &scanOpRegion = scanOp.getRegion(); auto &scanOpBlock = scanOpRegion.emplaceBlock(); @@ -270,7 +271,7 @@ static FailureOr createIntOrFloatCompareOp(PatternRewriter &rewriter, l = isEqual ? arith::CmpIPredicate::ule : arith::CmpIPredicate::ult; } arith::CmpIPredicate predicate = isDescending ? 
g : l; - compareOp = rewriter.create(loc, predicate, lhs, rhs); + compareOp = arith::CmpIOp::create(rewriter, loc, predicate, lhs, rhs); return compareOp; } @@ -282,7 +283,7 @@ static FailureOr createIntOrFloatCompareOp(PatternRewriter &rewriter, isEqual ? arith::CmpFPredicate::OLE : arith::CmpFPredicate::OLT; arith::CmpFPredicate predicate = isDescending ? g : l; - compareOp = rewriter.create(loc, predicate, lhs, rhs); + compareOp = arith::CmpFOp::create(rewriter, loc, predicate, lhs, rhs); return compareOp; } @@ -302,9 +303,9 @@ createTMTensorSortOp(PatternRewriter &rewriter, Location sortOpLoc, sortResultTypes.push_back(val.getType()); } ValueRange inputs; - auto sortOp = rewriter.create( - sortOpLoc, sortResultTypes, inputs, operands, - rewriter.getI64IntegerAttr(dimension)); + auto sortOp = + TMTensor::SortOp::create(rewriter, sortOpLoc, sortResultTypes, inputs, + operands, rewriter.getI64IntegerAttr(dimension)); // Step 2. Add two arguments for each element type in the SortOp's block. Region *body = &sortOp.getRegion(); @@ -325,7 +326,7 @@ createTMTensorSortOp(PatternRewriter &rewriter, Location sortOpLoc, loc, "Only Integer and Floating element type expected."); // Step 4. Create yield op for yielding the sorting predicate. - rewriter.create(loc, compareOpRetVal.value()); + TMTensor::YieldOp::create(rewriter, loc, compareOpRetVal.value()); return SmallVector(sortOp.getResults()); } @@ -341,9 +342,9 @@ static FailureOr> createTMTensorTopkOp( } // Create empty TopkOp, add body later. 
- auto topkOp = rewriter.create( - topkOpLoc, topkResultTypes, inputs, outputs, - rewriter.getI64IntegerAttr(dimension)); + auto topkOp = + TMTensor::TopkOp::create(rewriter, topkOpLoc, topkResultTypes, inputs, + outputs, rewriter.getI64IntegerAttr(dimension)); Region *body = &topkOp.getRegion(); Block *block = rewriter.createBlock(body); @@ -366,7 +367,7 @@ static FailureOr> createTMTensorTopkOp( loc, "Only Integer and Floating element type expected."); // Yield the comparison result. - rewriter.create(loc, compareOpRetVal.value()); + TMTensor::YieldOp::create(rewriter, loc, compareOpRetVal.value()); return SmallVector(topkOp.getResults()); } @@ -381,9 +382,9 @@ repeatTensorElementsForDim(Operation *op, ConversionPatternRewriter &rewriter, int64_t inputRank = selfTy.getSizes().size(); dim = toPositiveDim(dim, inputRank); Value dimValue = - rewriter.create(loc, rewriter.getI64IntegerAttr(dim)); + ConstantIntOp::create(rewriter, loc, rewriter.getI64IntegerAttr(dim)); Value dimValuePlusOne = - rewriter.create(loc, rewriter.getI64IntegerAttr(dim + 1)); + ConstantIntOp::create(rewriter, loc, rewriter.getI64IntegerAttr(dim + 1)); auto unsqueezedInfo = unsqueezeTensor(rewriter, op, self, dimValuePlusOne); if (failed(unsqueezedInfo)) @@ -392,12 +393,13 @@ repeatTensorElementsForDim(Operation *op, ConversionPatternRewriter &rewriter, self = *unsqueezedInfo; Value constMinusOne = - rewriter.create(loc, rewriter.getI64IntegerAttr(-1)); + ConstantIntOp::create(rewriter, loc, rewriter.getI64IntegerAttr(-1)); SmallVector expandShapeValueList(inputRank + 1, constMinusOne); expandShapeValueList[dim + 1] = - rewriter.create(loc, rewriter.getI64IntegerAttr(repeats)); - Value expandShapeList = rewriter.create( - loc, ListType::get(IntType::get(context)), expandShapeValueList); + ConstantIntOp::create(rewriter, loc, rewriter.getI64IntegerAttr(repeats)); + Value expandShapeList = PrimListConstructOp::create( + rewriter, loc, ListType::get(IntType::get(context)), + 
expandShapeValueList); SmallVector expandShape(inputRank + 1); for (int64_t i = 0; i <= dim; i++) { @@ -411,10 +413,10 @@ repeatTensorElementsForDim(Operation *op, ConversionPatternRewriter &rewriter, BaseTensorType expandTy = rewriter.getType(expandShape, selfTy.getOptionalDtype()); Value expandSelf = - rewriter.create(loc, expandTy, self, expandShapeList); + AtenBroadcastToOp::create(rewriter, loc, expandTy, self, expandShapeList); - Value result = rewriter.create(loc, resType, expandSelf, - dimValue, dimValuePlusOne); + Value result = PrimsCollapseOp::create(rewriter, loc, resType, expandSelf, + dimValue, dimValuePlusOne); return result; } @@ -460,16 +462,16 @@ class ConvertAtenScatterOp : public OpConversionPattern { [&](OpBuilder &b, Location loc, Value updatesElement, Value inputElement) { if (isa(op)) { - b.create(loc, updatesElement); + TMTensor::YieldOp::create(b, loc, updatesElement); } else if (isa(op)) { if (isa(selfType.getElementType())) { Value add = - b.create(loc, inputElement, updatesElement); - b.create(loc, add); + arith::AddIOp::create(b, loc, inputElement, updatesElement); + TMTensor::YieldOp::create(b, loc, add); } else if (isa(selfType.getElementType())) { Value add = - b.create(loc, inputElement, updatesElement); - b.create(loc, add); + arith::AddFOp::create(b, loc, inputElement, updatesElement); + TMTensor::YieldOp::create(b, loc, add); } } }); @@ -530,14 +532,14 @@ class ConvertAtenBincountOp : public OpConversionPattern { context, llvm::ArrayRef(maxTensorSizes), cast(torchTypeInput.getType()).getDtype()); Value maxTensor = - rewriter.create(loc, maxTensorType, torchTypeInput); + AtenMaxOp::create(rewriter, loc, maxTensorType, torchTypeInput); maxTensor = typeConverter->materializeTargetConversion( rewriter, loc, typeConverter->convertType(maxTensor.getType()), maxTensor); // `maxTensor` is a 0-d tensor, extracting its only element and // storing it in `maxInput`. 
- Value maxInput = rewriter.create(loc, maxTensor); + Value maxInput = tensor::ExtractOp::create(rewriter, loc, maxTensor); // Creating a tm_tensor.scatter op with the following mapping: // 1.) `input` tensor maps to the indices in scatter op. `input` is @@ -551,10 +553,10 @@ class ConvertAtenBincountOp : public OpConversionPattern { ValueTensorType expandInputType = ValueTensorType::get( context, llvm::ArrayRef(expandedInputSizes), cast(torchTypeInput.getType()).getDtype()); - Value torchCstOne = rewriter.create( - loc, rewriter.getI64IntegerAttr(1)); - Value expandedInputTensor = rewriter.create( - loc, expandInputType, torchTypeInput, torchCstOne); + Value torchCstOne = Torch::ConstantIntOp::create( + rewriter, loc, rewriter.getI64IntegerAttr(1)); + Value expandedInputTensor = AtenUnsqueezeOp::create( + rewriter, loc, expandInputType, torchTypeInput, torchCstOne); Value indices = typeConverter->materializeTargetConversion( rewriter, loc, @@ -567,19 +569,19 @@ class ConvertAtenBincountOp : public OpConversionPattern { SmallVector inputSizeDynamic = getTensorSizesUntilDim(rewriter, loc, input, 0); - Value updatesTensor = rewriter.create( - loc, getAsOpFoldResult(inputSizeDynamic), resultElemType); + Value updatesTensor = tensor::EmptyOp::create( + rewriter, loc, getAsOpFoldResult(inputSizeDynamic), resultElemType); - Value constantZero = rewriter.create( - loc, rewriter.getZeroAttr(resultElemType)); - Value constantOne = rewriter.create( - loc, 1, resultElemType.getIntOrFloatBitWidth()); + Value constantZero = arith::ConstantOp::create( + rewriter, loc, rewriter.getZeroAttr(resultElemType)); + Value constantOne = arith::ConstantIntOp::create( + rewriter, loc, 1, resultElemType.getIntOrFloatBitWidth()); // Bincount size = max(max(input) + 1, minlength) Value maxInputPlusOne = - rewriter.create(loc, maxInput, constantOne); + arith::AddIOp::create(rewriter, loc, maxInput, constantOne); Value bincountSize = - rewriter.create(loc, maxInputPlusOne, minlength); + 
arith::MaxSIOp::create(rewriter, loc, maxInputPlusOne, minlength); bincountSize = castIntToIndex(rewriter, loc, bincountSize); Value bincountTensor = createInitTensor(rewriter, loc, {bincountSize}, resultElemType, constantZero); @@ -588,8 +590,8 @@ class ConvertAtenBincountOp : public OpConversionPattern { rewriter, loc, updatesTensor, indices, bincountTensor, /*dimensionsMap=*/createDefaultDimMap(indices), /*uniqueIndices=*/false, [&](OpBuilder &b, Location loc, Value _, Value bincountElem) { - Value add = b.create(loc, bincountElem, constantOne); - b.create(loc, add); + Value add = arith::AddIOp::create(b, loc, bincountElem, constantOne); + TMTensor::YieldOp::create(b, loc, add); }); rewriter.replaceOpWithNewOp(op, resultType, scatterOp); return success(); @@ -616,7 +618,7 @@ getBroadcastShape(Location loc, llvm::ArrayRef indices, OpBuilder b) { }; Value torchCstOne = - b.create(loc, b.getI64IntegerAttr(1)); + Torch::ConstantIntOp::create(b, loc, b.getI64IntegerAttr(1)); llvm::SmallVector broadcastSizes(indicesRank, torchCstOne); llvm::SmallVector broadcastShape(indicesRank, 0); for (auto index : indices) { @@ -625,13 +627,13 @@ getBroadcastShape(Location loc, llvm::ArrayRef indices, OpBuilder b) { int32_t rank = shape.size(); for (int32_t j = 0; j < rank; ++j) { - Value dim = b.create(loc, b.getI64IntegerAttr(j)); - auto sizeOp = b.create(loc, index, dim); + Value dim = Torch::ConstantIntOp::create(b, loc, b.getI64IntegerAttr(j)); + auto sizeOp = Torch::AtenSizeIntOp::create(b, loc, index, dim); auto size = shape[j]; int32_t idx = broadcastShape.size() - rank + j; broadcastSizes[idx] = - b.create(loc, sizeOp, broadcastSizes[idx]); + Torch::PrimMaxIntOp::create(b, loc, sizeOp, broadcastSizes[idx]); broadcastShape[idx] = maxDim(size, broadcastShape[idx]); } } @@ -643,11 +645,11 @@ Value combinePutIndices(Location loc, llvm::ArrayRef indicesRef, llvm::SmallVector indices(indicesRef); // Declare commonly used constants up front: Value torchCstZero = - b.create(loc, 
b.getI64IntegerAttr(0)); + Torch::ConstantIntOp::create(b, loc, b.getI64IntegerAttr(0)); Value torchCstOne = - b.create(loc, b.getI64IntegerAttr(1)); + Torch::ConstantIntOp::create(b, loc, b.getI64IntegerAttr(1)); Value torchCstNegOne = - b.create(loc, b.getI64IntegerAttr(-1)); + Torch::ConstantIntOp::create(b, loc, b.getI64IntegerAttr(-1)); auto [broadcastSizes, broadcastShape] = getBroadcastShape(loc, indicesRef, b); @@ -663,19 +665,20 @@ Value combinePutIndices(Location loc, llvm::ArrayRef indicesRef, } // Broadcast together and flatten to batch values: - Value broadcastSizeList = b.create( - loc, Torch::ListType::get(b.getType()), broadcastSizes); + Value broadcastSizeList = PrimListConstructOp::create( + b, loc, Torch::ListType::get(b.getType()), + broadcastSizes); for (Value &index : indices) { auto indexTy = cast(index.getType()); auto expandTy = b.getType( broadcastShape, indexTy.getOptionalDtype()); - index = b.create(loc, expandTy, index, - broadcastSizeList); + index = Torch::AtenBroadcastToOp::create(b, loc, expandTy, index, + broadcastSizeList); auto flattenTy = b.getType( scatterBatchCount, indexTy.getOptionalDtype()); - index = b.create( - loc, flattenTy, index, torchCstZero, torchCstNegOne); + index = Torch::AtenFlattenUsingIntsOp::create(b, loc, flattenTy, index, + torchCstZero, torchCstNegOne); } // Unsqueeze so we have a 1 dim to concat along: @@ -689,20 +692,20 @@ Value combinePutIndices(Location loc, llvm::ArrayRef indicesRef, auto unsqueezeTy = b.getType(shape, btt.getDtype()); Value unsqueezed = - b.create(loc, unsqueezeTy, tensor, torchCstOne); + AtenUnsqueezeOp::create(b, loc, unsqueezeTy, tensor, torchCstOne); tensor = unsqueezed; } BaseTensorType unsqueezedTensorType = cast(indices[0].getType()); - Value indicesTorchList = b.create( - loc, Torch::ListType::get(unsqueezedTensorType), indices); + Value indicesTorchList = PrimListConstructOp::create( + b, loc, Torch::ListType::get(unsqueezedTensorType), indices); llvm::SmallVector 
concatShape{ unsqueezedTensorType.getSizes()[0], static_cast(indices.size())}; ValueTensorType concatIndicesType = b.getType( llvm::ArrayRef(concatShape), unsqueezedTensorType.getDtype()); - return b.create(loc, concatIndicesType, indicesTorchList, - torchCstOne); + return AtenCatOp::create(b, loc, concatIndicesType, indicesTorchList, + torchCstOne); } // Helper that collapses the batch dimensions together and moves it to the front @@ -720,14 +723,14 @@ static Value collapseAndMoveBatchDims(Location loc, Value values, int64_t batch, // We need a length-1 dim at the start to transpose the batch to: if (batch != 0) { - outDims.push_back(b.create(loc, 1)); + outDims.push_back(Torch::ConstantIntOp::create(b, loc, 1)); outShape.push_back(1); } // Dimensions before the batch stay the same: for (int i = 0; i <= batch; i++) { - auto k = b.create(loc, b.getI64IntegerAttr(i)); - auto dim = b.create(loc, values, k); + auto k = Torch::ConstantIntOp::create(b, loc, b.getI64IntegerAttr(i)); + auto dim = Torch::AtenSizeIntOp::create(b, loc, values, k); outDims.push_back(dim); outShape.push_back(inShape[i]); } @@ -743,25 +746,25 @@ static Value collapseAndMoveBatchDims(Location loc, Value values, int64_t batch, outShape.back() = mulI(outShape.back(), inShape[batch + i]); auto k = - b.create(loc, b.getI64IntegerAttr(batch + i)); - auto dim = b.create(loc, values, k); - outDims.back() = b.create(loc, dim, outDims.back()); + Torch::ConstantIntOp::create(b, loc, b.getI64IntegerAttr(batch + i)); + auto dim = Torch::AtenSizeIntOp::create(b, loc, values, k); + outDims.back() = Torch::AtenMulIntOp::create(b, loc, dim, outDims.back()); } // Add the dimensions after the batch dims: for (int i = batch + count, s = inShape.size(); i < s; ++i) { - auto k = b.create(loc, b.getI64IntegerAttr(i)); - auto dim = b.create(loc, values, k); + auto k = Torch::ConstantIntOp::create(b, loc, b.getI64IntegerAttr(i)); + auto dim = Torch::AtenSizeIntOp::create(b, loc, values, k); outDims.push_back(dim); 
outShape.push_back(inShape[i]); } - Value outDimsList = b.create( - loc, Torch::ListType::get(b.getType()), outDims); + Value outDimsList = PrimListConstructOp::create( + b, loc, Torch::ListType::get(b.getType()), outDims); valuesTy = b.getType(outShape, valuesTy.getOptionalDtype()); - values = b.create(loc, valuesTy, values, outDimsList); + values = AtenViewOp::create(b, loc, valuesTy, values, outDimsList); if (batch == 0) return values; @@ -770,14 +773,14 @@ static Value collapseAndMoveBatchDims(Location loc, Value values, int64_t batch, std::swap(outDims[0], outDims[batch + 1]); std::swap(outShape[0], outShape[batch + 1]); - Value dim0 = b.create(loc, b.getI64IntegerAttr(0)); + Value dim0 = Torch::ConstantIntOp::create(b, loc, b.getI64IntegerAttr(0)); Value dimB = - b.create(loc, b.getI64IntegerAttr(batch + 1)); + Torch::ConstantIntOp::create(b, loc, b.getI64IntegerAttr(batch + 1)); valuesTy = b.getType(outShape, valuesTy.getOptionalDtype()); values = - b.create(loc, valuesTy, values, dim0, dimB); + Torch::AtenTransposeIntOp::create(b, loc, valuesTy, values, dim0, dimB); outDims.clear(); outShape.clear(); @@ -786,16 +789,16 @@ static Value collapseAndMoveBatchDims(Location loc, Value values, int64_t batch, for (int i = 0; i < transposeRank; ++i) { if (i == batch + 1) continue; - Value k = b.create(loc, b.getI64IntegerAttr(i)); - outDims.push_back(b.create(loc, values, k)); + Value k = Torch::ConstantIntOp::create(b, loc, b.getI64IntegerAttr(i)); + outDims.push_back(AtenSizeIntOp::create(b, loc, values, k)); outShape.push_back(transposeShape[i]); } valuesTy = b.getType(outShape, valuesTy.getOptionalDtype()); - outDimsList = b.create( - loc, Torch::ListType::get(b.getType()), outDims); - return b.create(loc, valuesTy, values, outDimsList); + outDimsList = PrimListConstructOp::create( + b, loc, Torch::ListType::get(b.getType()), outDims); + return AtenViewOp::create(b, loc, valuesTy, values, outDimsList); } // Broadcast the `values` tensor to the slice size created 
by the list of index @@ -813,17 +816,17 @@ static Value broadcastValuesToSliceSize(Location loc, Value input, Value values, // of the indexed slice. auto [resultShape, resultStaticShape] = getBroadcastShape(loc, indices, b); for (size_t i = indices.size(); i < inputStaticShape.size(); i++) { - Value dim = b.create(loc, b.getI64IntegerAttr(i)); - resultShape.push_back(b.create(loc, input, dim)); + Value dim = Torch::ConstantIntOp::create(b, loc, b.getI64IntegerAttr(i)); + resultShape.push_back(AtenSizeIntOp::create(b, loc, input, dim)); resultStaticShape.push_back(inputStaticShape[i]); } auto resultType = b.getType( resultStaticShape, valuesType.getOptionalDtype()); - Value broadcastShapeList = b.create( - loc, Torch::ListType::get(b.getType()), resultShape); - return b.create(loc, resultType, values, - broadcastShapeList); + Value broadcastShapeList = PrimListConstructOp::create( + b, loc, Torch::ListType::get(b.getType()), resultShape); + return AtenBroadcastToOp::create(b, loc, resultType, values, + broadcastShapeList); } class ConvertAtenIndexPutHackedTwinOp @@ -915,10 +918,10 @@ class ConvertAtenIndexPutHackedTwinOp valuesType = cast(values.getType()); // Materialize out the length-1 dimensions: - Value zero = rewriter.create( - loc, rewriter.getI64IntegerAttr(0)); - Value one = rewriter.create( - loc, rewriter.getI64IntegerAttr(1)); + Value zero = Torch::ConstantIntOp::create(rewriter, loc, + rewriter.getI64IntegerAttr(0)); + Value one = Torch::ConstantIntOp::create(rewriter, loc, + rewriter.getI64IntegerAttr(1)); llvm::SmallVector valuesShape; llvm::SmallVector valuesDims; int vDim = 0; @@ -927,7 +930,7 @@ class ConvertAtenIndexPutHackedTwinOp inputType.getSizes().size()) { valuesShape.push_back(valuesType.getSizes().front()); valuesDims.push_back( - rewriter.create(loc, values, zero)); + Torch::AtenSizeIntOp::create(rewriter, loc, values, zero)); vDim++; } @@ -939,22 +942,22 @@ class ConvertAtenIndexPutHackedTwinOp continue; } - Value k = rewriter.create( - 
loc, rewriter.getI64IntegerAttr(vDim)); + Value k = Torch::ConstantIntOp::create(rewriter, loc, + rewriter.getI64IntegerAttr(vDim)); valuesDims.push_back( - rewriter.create(loc, values, k)); + Torch::AtenSizeIntOp::create(rewriter, loc, values, k)); valuesShape.push_back(inputType.getSizes()[i]); vDim++; } - Value valuesDimsList = rewriter.create( - loc, Torch::ListType::get(rewriter.getType()), + Value valuesDimsList = PrimListConstructOp::create( + rewriter, loc, Torch::ListType::get(rewriter.getType()), valuesDims); valuesType = rewriter.getType( valuesShape, valuesType.getOptionalDtype()); values = - rewriter.create(loc, valuesType, values, valuesDimsList); + AtenViewOp::create(rewriter, loc, valuesType, values, valuesDimsList); input = typeConverter->materializeTargetConversion( rewriter, loc, typeConverter->convertType(input.getType()), input); @@ -981,16 +984,16 @@ class ConvertAtenIndexPutHackedTwinOp if (accumulate) { if (isa(inputElement.getType())) { yieldValue = - b.create(loc, inputElement, valuesElement); + arith::AddIOp::create(b, loc, inputElement, valuesElement); } else if (isa(inputElement.getType())) { yieldValue = - b.create(loc, inputElement, valuesElement); + arith::AddFOp::create(b, loc, inputElement, valuesElement); } else { invalidInputTypeFound = true; return; } } - b.create(loc, yieldValue); + TMTensor::YieldOp::create(b, loc, yieldValue); }); if (invalidInputTypeFound) { @@ -1110,48 +1113,46 @@ class ConvertAtenMaxPool2dWithIndicesBackwardOp getAsOpFoldResult(getTensorSizes(rewriter, loc, indices)); updatedIndicesShape.push_back(rewriter.getIndexAttr(tensorOperandRank)); - Value initTensor = rewriter.create( - loc, updatedIndicesShape, indicesElemType); + Value initTensor = tensor::EmptyOp::create( + rewriter, loc, updatedIndicesShape, indicesElemType); Value wIn = inputShape[tensorOperandRank - 1]; SmallVector cstValues; for (int64_t i = 0; i < tensorOperandRank; i++) - cstValues.push_back(rewriter.create(loc, i)); + 
cstValues.push_back(arith::ConstantIndexOp::create(rewriter, loc, i)); Value updatedIndices = - rewriter - .create( - loc, initTensor.getType(), indices, initTensor, indexingMaps, - iteratorTypes, - [tensorOperandRank, wIn, cstValues, - indicesElemType](OpBuilder &b, Location loc, ValueRange args) { - Value index = castIntToIndex(b, loc, args[0]); - Value updatedIndex = cstValues[0]; - Value lastDim = - b.create(loc, tensorOperandRank); - - for (int64_t i = tensorOperandRank - 1; i >= 0; i--) { - Value result; - if (i == tensorOperandRank - 1) - result = b.create(loc, index, wIn); - if (i == tensorOperandRank - 2) - result = b.create(loc, index, wIn); - if (i == tensorOperandRank - 3 || - i == tensorOperandRank - 4) - result = b.create(loc, i); - - Value pred = b.create( - loc, arith::CmpIPredicate::eq, lastDim, cstValues[i]); - Value addAmount = b.create( - loc, pred, result, cstValues[0]); - updatedIndex = - b.create(loc, updatedIndex, addAmount); - } - - updatedIndex = b.create( - loc, indicesElemType, updatedIndex); - b.create(loc, updatedIndex); - }) + linalg::GenericOp::create( + rewriter, loc, initTensor.getType(), indices, initTensor, + indexingMaps, iteratorTypes, + [tensorOperandRank, wIn, cstValues, + indicesElemType](OpBuilder &b, Location loc, ValueRange args) { + Value index = castIntToIndex(b, loc, args[0]); + Value updatedIndex = cstValues[0]; + Value lastDim = + linalg::IndexOp::create(b, loc, tensorOperandRank); + + for (int64_t i = tensorOperandRank - 1; i >= 0; i--) { + Value result; + if (i == tensorOperandRank - 1) + result = arith::RemSIOp::create(b, loc, index, wIn); + if (i == tensorOperandRank - 2) + result = arith::FloorDivSIOp::create(b, loc, index, wIn); + if (i == tensorOperandRank - 3 || i == tensorOperandRank - 4) + result = linalg::IndexOp::create(b, loc, i); + + Value pred = arith::CmpIOp::create( + b, loc, arith::CmpIPredicate::eq, lastDim, cstValues[i]); + Value addAmount = + arith::SelectOp::create(b, loc, pred, result, 
cstValues[0]); + updatedIndex = + arith::AddIOp::create(b, loc, updatedIndex, addAmount); + } + + updatedIndex = arith::IndexCastOp::create(b, loc, indicesElemType, + updatedIndex); + linalg::YieldOp::create(b, loc, updatedIndex); + }) .getResult(0); // Creating a new tensor initialized with zeros and size same as the input @@ -1167,8 +1168,9 @@ class ConvertAtenMaxPool2dWithIndicesBackwardOp int64_t numelGradOutput = getNumberOfElements(gradOutputType); gradOutputFlattenedType = RankedTensorType::get( makeShapeLLVMCompatible({numelGradOutput}), gradOutputElemType); - Value gradOutputFlattened = rewriter.create( - loc, gradOutputFlattenedType, gradOutput, reassociationCollapse); + Value gradOutputFlattened = + tensor::CollapseShapeOp::create(rewriter, loc, gradOutputFlattenedType, + gradOutput, reassociationCollapse); // Collapsing updated indices into a 2-d tensor. SmallVector reassociationCollapseIndices(2); @@ -1176,8 +1178,8 @@ class ConvertAtenMaxPool2dWithIndicesBackwardOp reassociationCollapseIndices[0].push_back(i); reassociationCollapseIndices[1].push_back(tensorOperandRank); int64_t numelIndices = getNumberOfElements(indicesType); - Value indicesCollapsed = rewriter.create( - loc, + Value indicesCollapsed = tensor::CollapseShapeOp::create( + rewriter, loc, RankedTensorType::get( makeShapeLLVMCompatible({numelIndices, tensorOperandRank}), indicesElemType), @@ -1194,15 +1196,15 @@ class ConvertAtenMaxPool2dWithIndicesBackwardOp Value yieldValue = valuesElement; if (isa(inputElement.getType())) { yieldValue = - b.create(loc, inputElement, valuesElement); + arith::AddIOp::create(b, loc, inputElement, valuesElement); } else if (isa(inputElement.getType())) { yieldValue = - b.create(loc, inputElement, valuesElement); + arith::AddFOp::create(b, loc, inputElement, valuesElement); } else { invalidInputTypeFound = true; return; } - b.create(loc, yieldValue); + TMTensor::YieldOp::create(b, loc, yieldValue); }); if (invalidInputTypeFound) { @@ -1283,7 +1285,7 @@ 
class ConvertAtenScatterReduceTwoOp } else { llvm_unreachable("Only integer/float types supported!"); } - Value initElement = rewriter.create(loc, initAttr); + Value initElement = arith::ConstantOp::create(rewriter, loc, initAttr); counts = createInitTensor(rewriter, loc, selfShape, selfType.getElementType(), initElement); } @@ -1295,16 +1297,18 @@ class ConvertAtenScatterReduceTwoOp if (reduceEnum == torch_upstream::ReductionType::SUM || reduceEnum == torch_upstream::ReductionType::MEAN) { // Set the values in the input tensor to '0' so they are not included - normalizationValue = rewriter.create( - loc, rewriter.getZeroAttr(srcType.getElementType())); + normalizationValue = arith::ConstantOp::create( + rewriter, loc, rewriter.getZeroAttr(srcType.getElementType())); } else if (reduceEnum == torch_upstream::ReductionType::PROD) { // Set the values in the input tensor to '1' (multiplication identity) if (llvm::isa(srcType.getElementType())) { - normalizationValue = rewriter.create( - loc, rewriter.getFloatAttr(srcType.getElementType(), 1.0)); + normalizationValue = arith::ConstantOp::create( + rewriter, loc, + rewriter.getFloatAttr(srcType.getElementType(), 1.0)); } else if (llvm::isa(srcType.getElementType())) { - normalizationValue = rewriter.create( - loc, rewriter.getIntegerAttr(srcType.getElementType(), 1)); + normalizationValue = arith::ConstantOp::create( + rewriter, loc, + rewriter.getIntegerAttr(srcType.getElementType(), 1)); } else { llvm_unreachable("Only integer/float types supported!"); } @@ -1313,13 +1317,13 @@ class ConvertAtenScatterReduceTwoOp // type TypedAttr minAttr = getNumericLimit(rewriter, srcType.getElementType(), /*getMin=*/true); - normalizationValue = rewriter.create(loc, minAttr); + normalizationValue = arith::ConstantOp::create(rewriter, loc, minAttr); } else if (reduceEnum == torch_upstream::ReductionType::MIN) { // Set the values in the input tensor to the largest element of that // type TypedAttr maxAttr = getNumericLimit(rewriter, 
srcType.getElementType(), /*getMin=*/false); - normalizationValue = rewriter.create(loc, maxAttr); + normalizationValue = arith::ConstantOp::create(rewriter, loc, maxAttr); } // Scatter the normalizations into the input tensor @@ -1333,7 +1337,7 @@ class ConvertAtenScatterReduceTwoOp /*dimensionsMap=*/createDefaultDimMap(indices), /*uniqueIndices=*/false, [&](OpBuilder &b, Location loc, Value update, Value current) { - b.create(loc, update); + TMTensor::YieldOp::create(b, loc, update); }); if (reduceEnum == torch_upstream::ReductionType::MEAN) { counts = createTMTensorScatterOp( @@ -1341,7 +1345,7 @@ class ConvertAtenScatterReduceTwoOp /*dimensionsMap=*/createDefaultDimMap(indices), /*uniqueIndices=*/false, [&](OpBuilder &b, Location loc, Value update, Value current) { - b.create(loc, update); + TMTensor::YieldOp::create(b, loc, update); }); } } @@ -1355,38 +1359,38 @@ class ConvertAtenScatterReduceTwoOp if (reduceEnum == torch_upstream::ReductionType::SUM || reduceEnum == torch_upstream::ReductionType::MEAN) { if (isa(update.getType())) { - result = b.create(loc, update, current); + result = arith::AddIOp::create(b, loc, update, current); } else if (isa(update.getType())) { - result = b.create(loc, update, current); + result = arith::AddFOp::create(b, loc, update, current); } else { llvm_unreachable("Only integer/float types supported!"); } } else if (reduceEnum == torch_upstream::ReductionType::PROD) { if (isa(update.getType())) { - result = b.create(loc, update, current); + result = arith::MulIOp::create(b, loc, update, current); } else if (isa(update.getType())) { - result = b.create(loc, update, current); + result = arith::MulFOp::create(b, loc, update, current); } else { llvm_unreachable("Only integer/float types supported!"); } } else if (reduceEnum == torch_upstream::ReductionType::MAX) { if (isa(update.getType())) { - result = b.create(loc, update, current); + result = arith::MaxSIOp::create(b, loc, update, current); } else if (isa(update.getType())) { - 
result = b.create(loc, update, current); + result = arith::MaximumFOp::create(b, loc, update, current); } else { llvm_unreachable("Only integer/float types supported!"); } } else if (reduceEnum == torch_upstream::ReductionType::MIN) { if (isa(update.getType())) { - result = b.create(loc, update, current); + result = arith::MinSIOp::create(b, loc, update, current); } else if (isa(update.getType())) { - result = b.create(loc, update, current); + result = arith::MinimumFOp::create(b, loc, update, current); } else { llvm_unreachable("Only integer/float types supported!"); } } - b.create(loc, result); + TMTensor::YieldOp::create(b, loc, result); }); // Special case for the mean @@ -1399,40 +1403,39 @@ class ConvertAtenScatterReduceTwoOp Value result; if (mlir::IntegerType intType = llvm::dyn_cast(current.getType())) { - Value constantUpdate = b.create( - loc, b.getIntegerAttr(intType, 1)); - result = b.create(loc, constantUpdate, current); + Value constantUpdate = arith::ConstantOp::create( + b, loc, b.getIntegerAttr(intType, 1)); + result = arith::AddIOp::create(b, loc, constantUpdate, current); } else if (mlir::FloatType floatType = llvm::dyn_cast(current.getType())) { - Value constantUpdate = b.create( - loc, b.getFloatAttr(floatType, 1.0)); - result = b.create(loc, constantUpdate, current); + Value constantUpdate = arith::ConstantOp::create( + b, loc, b.getFloatAttr(floatType, 1.0)); + result = arith::AddFOp::create(b, loc, constantUpdate, current); } else { llvm_unreachable("Only integer/float types supported!"); } - b.create(loc, result); + TMTensor::YieldOp::create(b, loc, result); }); - Value output = rewriter.create( - loc, tensor::getMixedSizes(rewriter, loc, self), + Value output = tensor::EmptyOp::create( + rewriter, loc, tensor::getMixedSizes(rewriter, loc, self), selfType.getElementType()); // Finally divide the result scatterOp = - rewriter - .create( - loc, ValueRange{scatterOp, counts}, output, - [&](OpBuilder &b, Location loc, ValueRange args) { - 
Value result; - if (llvm::isa(args[0].getType())) { - result = b.create(loc, args[0], args[1]); - } else if (llvm::isa(args[0].getType())) { - result = b.create(loc, args[0], args[1]); - } else { - llvm_unreachable("Only integer/float types supported!"); - } - b.create(loc, result); - }) + linalg::MapOp::create( + rewriter, loc, ValueRange{scatterOp, counts}, output, + [&](OpBuilder &b, Location loc, ValueRange args) { + Value result; + if (llvm::isa(args[0].getType())) { + result = arith::DivSIOp::create(b, loc, args[0], args[1]); + } else if (llvm::isa(args[0].getType())) { + result = arith::DivFOp::create(b, loc, args[0], args[1]); + } else { + llvm_unreachable("Only integer/float types supported!"); + } + linalg::YieldOp::create(b, loc, result); + }) .getResult()[0]; } auto resultType = cast( @@ -1484,26 +1487,25 @@ class ConvertAtenSortOp : public OpConversionPattern { SmallVector dynDims; for (unsigned i = 0; i < inputType.getRank(); i++) { if (inputType.isDynamicDim(i)) { - dynDims.push_back(rewriter.create(loc, inputTensor, i)); + dynDims.push_back(tensor::DimOp::create(rewriter, loc, inputTensor, i)); } } - Value initEmptyTensor = rewriter.create( - loc, inputType.getShape(), rewriter.getI64Type(), dynDims); + Value initEmptyTensor = tensor::EmptyOp::create( + rewriter, loc, inputType.getShape(), rewriter.getI64Type(), dynDims); SmallVector indexingMaps = { AffineMap::getMultiDimIdentityMap(inputRank, op.getContext())}; SmallVector iteratorTypes( inputRank, utils::IteratorType::parallel); Value indicesTensor = - rewriter - .create( - loc, initEmptyTensor.getType(), ValueRange{}, initEmptyTensor, - indexingMaps, iteratorTypes, - [&](OpBuilder &b, Location loc, ValueRange args) { - Value index = b.create(loc, dim); - index = castIndexToInt64(b, loc, index); - b.create(loc, index); - }) + linalg::GenericOp::create( + rewriter, loc, initEmptyTensor.getType(), ValueRange{}, + initEmptyTensor, indexingMaps, iteratorTypes, + [&](OpBuilder &b, Location loc, 
ValueRange args) { + Value index = linalg::IndexOp::create(b, loc, dim); + index = castIndexToInt64(b, loc, index); + linalg::YieldOp::create(b, loc, index); + }) .getResult(0); // Step 6. Create TMTensor::SortOp. @@ -1576,7 +1578,7 @@ class ConvertAtenCumprodOp : public OpConversionPattern { SmallVector sizes = getTensorSizes(rewriter, loc, input); Value output = createOneInitTensor(rewriter, loc, sizes, elementType); - output = rewriter.create(loc, resultType, output); + output = tensor::CastOp::create(rewriter, loc, resultType, output); SmallVector accSizes(sizes); accSizes.erase(accSizes.begin() + dim); @@ -1586,16 +1588,16 @@ class ConvertAtenCumprodOp : public OpConversionPattern { Value acc = createOneInitTensor(rewriter, loc, accSizes, elementType); Type accType = RankedTensorType::get(makeShapeLLVMCompatible(accStatic), elementType); - acc = rewriter.create(loc, accType, acc); + acc = tensor::CastOp::create(rewriter, loc, accType, acc); Value result = createTMTensorScanOp( rewriter, loc, input, output, acc, dim, /*inclusive=*/true, [](OpBuilder &b, Location loc, Value input, Value acc) { Value prod = (isa(input.getType()) - ? b.create(loc, input, acc)->getResult(0) - : b.create(loc, input, acc)->getResult(0)); - b.create(loc, prod); + ? 
arith::MulFOp::create(b, loc, input, acc)->getResult(0) + : arith::MulIOp::create(b, loc, input, acc)->getResult(0)); + TMTensor::YieldOp::create(b, loc, prod); }); rewriter.replaceOpWithNewOp(op, resultType, result); @@ -1665,7 +1667,7 @@ class ConvertAtenCumsumOp : public OpConversionPattern { SmallVector sizes = getTensorSizes(rewriter, loc, input); Value output = createZeroInitTensor(rewriter, loc, sizes, elementType); - output = rewriter.create(loc, resultType, output); + output = tensor::CastOp::create(rewriter, loc, resultType, output); SmallVector accSizes(sizes); accSizes.erase(accSizes.begin() + dim); @@ -1675,16 +1677,16 @@ class ConvertAtenCumsumOp : public OpConversionPattern { Value acc = createZeroInitTensor(rewriter, loc, accSizes, elementType); Type accType = RankedTensorType::get(makeShapeLLVMCompatible(accStatic), elementType); - acc = rewriter.create(loc, accType, acc); + acc = tensor::CastOp::create(rewriter, loc, accType, acc); Value result = createTMTensorScanOp( rewriter, loc, input, output, acc, dim, /*inclusive=*/true, [](OpBuilder &b, Location loc, Value input, Value acc) { Value sum = (isa(input.getType()) - ? b.create(loc, input, acc)->getResult(0) - : b.create(loc, input, acc)->getResult(0)); - b.create(loc, sum); + ? 
arith::AddFOp::create(b, loc, input, acc)->getResult(0) + : arith::AddIOp::create(b, loc, input, acc)->getResult(0)); + TMTensor::YieldOp::create(b, loc, sum); }); rewriter.replaceOpWithNewOp(op, resultType, result); @@ -1795,41 +1797,43 @@ class ConvertAtenScaledDotProductAttentionOp for (int i = 0, s = queryTy.getRank() - 1; i < s; ++i) { maskStatic.push_back(queryTy.getDimSize(i)); if (maskStatic.back() == ShapedType::kDynamic) - maskDyn.push_back(rewriter.create(loc, query, i)); + maskDyn.push_back(tensor::DimOp::create(rewriter, loc, query, i)); } maskStatic.push_back(keyTy.getDimSize(keyTy.getRank() - 2)); if (maskStatic.back() == ShapedType::kDynamic) maskDyn.push_back( - rewriter.create(loc, key, keyTy.getRank() - 2)); + tensor::DimOp::create(rewriter, loc, key, keyTy.getRank() - 2)); Type maskType = getElementTypeOrSelf(queryTy); Value emptyMask = - rewriter.create(loc, maskStatic, maskType, maskDyn); + tensor::EmptyOp::create(rewriter, loc, maskStatic, maskType, maskDyn); - Value zero = rewriter.create( - loc, rewriter.getFloatAttr(getElementTypeOrSelf(maskType), 0.0)); - Value negInf = rewriter.create( - loc, + Value zero = arith::ConstantOp::create( + rewriter, loc, + rewriter.getFloatAttr(getElementTypeOrSelf(maskType), 0.0)); + Value negInf = arith::ConstantOp::create( + rewriter, loc, rewriter.getFloatAttr(getElementTypeOrSelf(maskType), -INFINITY)); - mask = rewriter.create(loc, zero, emptyMask).getResult(0); + mask = + linalg::FillOp::create(rewriter, loc, zero, emptyMask).getResult(0); int64_t rank = cast(queryTy).getRank(); AffineMap maskMap = rewriter.getMultiDimIdentityMap(rank); SmallVector iteratorTypes( rank, utils::IteratorType::parallel); - auto genericOp = rewriter.create( - loc, mask.getType(), ValueRange{}, mask, + auto genericOp = linalg::GenericOp::create( + rewriter, loc, mask.getType(), ValueRange{}, mask, SmallVector{maskMap}, iteratorTypes, [&](OpBuilder &b, Location loc, ValueRange args) { - Value i = b.create(loc, 
queryTy.getRank() - 2); - Value j = b.create(loc, queryTy.getRank() - 1); + Value i = linalg::IndexOp::create(b, loc, queryTy.getRank() - 2); + Value j = linalg::IndexOp::create(b, loc, queryTy.getRank() - 1); Value cond = - b.create(loc, arith::CmpIPredicate::sge, i, j); - Value select = b.create(loc, cond, zero, negInf); - b.create(loc, select); + arith::CmpIOp::create(b, loc, arith::CmpIPredicate::sge, i, j); + Value select = arith::SelectOp::create(b, loc, cond, zero, negInf); + linalg::YieldOp::create(b, loc, select); }); mask = genericOp.getResult(0); } @@ -1859,7 +1863,7 @@ class ConvertAtenScaledDotProductAttentionOp if (queryTy.isDynamicDim(i)) { maskDynDims.push_back( - rewriter.create(loc, query, i)); + tensor::DimOp::create(rewriter, loc, query, i)); } } @@ -1869,10 +1873,10 @@ class ConvertAtenScaledDotProductAttentionOp maskShape.push_back(maskTy.getDimSize(rank - 1)); if (maskTy.isDynamicDim(rank - 2)) maskDynDims.push_back( - rewriter.create(loc, mask, rank - 2)); + tensor::DimOp::create(rewriter, loc, mask, rank - 2)); if (maskTy.isDynamicDim(rank - 1)) maskDynDims.push_back( - rewriter.create(loc, mask, rank - 1)); + tensor::DimOp::create(rewriter, loc, mask, rank - 1)); SmallVector affineMaps = { AffineMap::get(/*dimCount=*/rank, /*symbolCount=*/0, maskExprs, @@ -1881,16 +1885,15 @@ class ConvertAtenScaledDotProductAttentionOp SmallVector findMaxIteratorTypes( rank, utils::IteratorType::parallel); - Value emptyMask = rewriter.create( - loc, maskShape, maskTy.getElementType(), maskDynDims); + Value emptyMask = tensor::EmptyOp::create( + rewriter, loc, maskShape, maskTy.getElementType(), maskDynDims); Value newMask = - rewriter - .create( - loc, emptyMask.getType(), mask, ValueRange({emptyMask}), - affineMaps, findMaxIteratorTypes, - [&](OpBuilder &b, Location loc, ValueRange args) { - b.create(loc, args[0]); - }) + linalg::GenericOp::create( + rewriter, loc, emptyMask.getType(), mask, + ValueRange({emptyMask}), affineMaps, findMaxIteratorTypes, + 
[&](OpBuilder &b, Location loc, ValueRange args) { + linalg::YieldOp::create(b, loc, args[0]); + }) .getResult(0); mask = newMask; } @@ -1951,8 +1954,8 @@ class ConvertAtenScaledDotProductAttentionOp } auto collapseTy = valueTy.clone(newShape); - return rewriter.create(loc, collapseTy, value, - reassociation); + return tensor::CollapseShapeOp::create(rewriter, loc, collapseTy, value, + reassociation); }; query = collapseBatch(query); @@ -1980,14 +1983,13 @@ class ConvertAtenScaledDotProductAttentionOp } // Overwrite with tm_tensor::attention - Value attention = rewriter - .create(loc, outType, inputs, - SmallVector{output}) + Value attention = AttentionOp::create(rewriter, loc, outType, inputs, + SmallVector{output}) .getResult()[0]; if (opTy != outType) { - attention = rewriter.create(loc, opTy, attention, - reassociation); + attention = tensor::ExpandShapeOp::create(rewriter, loc, opTy, attention, + reassociation); } rewriter.replaceOp(op, attention); @@ -2068,16 +2070,16 @@ class ConvertAtenKthvalueOp : public OpConversionPattern { Value fillValTopK; if (isa(inputElementType)) { // max float for topk tensor - fillValTopK = rewriter.create( - loc, + fillValTopK = arith::ConstantOp::create( + rewriter, loc, rewriter.getFloatAttr( inputElementType, APFloat::getInf( cast(inputElementType).getFloatSemantics(), /*Negative=*/false))); // min float for linalg generic op tensor - fillValLinalgFindMax = rewriter.create( - loc, + fillValLinalgFindMax = arith::ConstantOp::create( + rewriter, loc, rewriter.getFloatAttr( inputElementType, APFloat::getInf( @@ -2087,22 +2089,22 @@ class ConvertAtenKthvalueOp : public OpConversionPattern { auto width = cast(inputElementType).getWidth(); // max signed int for topk op tensor auto init = APSInt::getSignedMaxValue(width); - fillValTopK = rewriter.create( - loc, rewriter.getIntegerAttr(inputElementType, init)); + fillValTopK = arith::ConstantOp::create( + rewriter, loc, rewriter.getIntegerAttr(inputElementType, init)); // min signed 
int for linalg generic op tensor init = APSInt::getSignedMinValue(width); - fillValLinalgFindMax = rewriter.create( - loc, rewriter.getIntegerAttr(inputElementType, init)); + fillValLinalgFindMax = arith::ConstantOp::create( + rewriter, loc, rewriter.getIntegerAttr(inputElementType, init)); } else if (isUnsigned) { auto width = cast(inputElementType).getWidth(); // max unsigned int for topk op tensor auto init = APInt::getMaxValue(width); - fillValTopK = rewriter.create( - loc, rewriter.getIntegerAttr(inputElementType, init)); + fillValTopK = arith::ConstantOp::create( + rewriter, loc, rewriter.getIntegerAttr(inputElementType, init)); // min unsigned int for linalg generic op tensor init = APInt::getMinValue(width); - fillValLinalgFindMax = rewriter.create( - loc, rewriter.getIntegerAttr(inputElementType, init)); + fillValLinalgFindMax = arith::ConstantOp::create( + rewriter, loc, rewriter.getIntegerAttr(inputElementType, init)); } auto i32Type = rewriter.getI32Type(); @@ -2116,11 +2118,11 @@ class ConvertAtenKthvalueOp : public OpConversionPattern { // except topkShape[dim] = k. SmallVector topkShape; for (unsigned i = 0; i < inputRank; i++) { - auto currentDimSize = rewriter.create(loc, input, i); + auto currentDimSize = tensor::DimOp::create(rewriter, loc, input, i); topkShape.push_back(currentDimSize); } - auto dimSize = rewriter.create( - loc, rewriter.getIntegerAttr(rewriter.getI64Type(), k)); + auto dimSize = arith::ConstantOp::create( + rewriter, loc, rewriter.getIntegerAttr(rewriter.getI64Type(), k)); topkShape[dim] = dimSize; // Fill the initial topk op output tensor. 
@@ -2132,7 +2134,7 @@ class ConvertAtenKthvalueOp : public OpConversionPattern { auto signlessType = mlir::IntegerType::get(op.getContext(), 32, mlir::IntegerType::Signless); auto initIdx = getNumericLimit(rewriter, signlessType, /*getMin=*/false); - auto fillValTopkIdx = rewriter.create(loc, initIdx); + auto fillValTopkIdx = arith::ConstantOp::create(rewriter, loc, initIdx); // Fill the initial topk op output indices tensor. Value topkOutputIdx = createInitTensor(rewriter, loc, topkShape, i32Type, fillValTopkIdx); @@ -2181,7 +2183,7 @@ class ConvertAtenKthvalueOp : public OpConversionPattern { SmallVector resultShapeInt; for (int64_t i = 0; i < inputType.getRank(); i++) { if (dim != i) { - auto currentDimSize = rewriter.create(loc, input, i); + auto currentDimSize = tensor::DimOp::create(rewriter, loc, input, i); resultShape.push_back(currentDimSize); resultShapeInt.push_back(inputType.getShape()[i]); } @@ -2216,8 +2218,8 @@ class ConvertAtenKthvalueOp : public OpConversionPattern { rewriter.getContext()); // Create linalg op for finding the max value in the extracted topk values. 
- auto findMaxLinalg = rewriter.create( - loc, + auto findMaxLinalg = linalg::GenericOp::create( + rewriter, loc, ArrayRef( {findMaxOutputVal.getType(), findMaxOutputIdx.getType()}), topkOpVal.front(), ValueRange({findMaxOutputVal, findMaxOutputIdx}), @@ -2231,34 +2233,35 @@ class ConvertAtenKthvalueOp : public OpConversionPattern { Value oldValue = blockArgs[1]; Value oldIndex = blockArgs[2]; - Value newIndex = rewriter.create( - nestedLoc, oldIndex.getType(), - rewriter.create(nestedLoc, dim)); + Value newIndex = arith::IndexCastOp::create( + rewriter, nestedLoc, oldIndex.getType(), + linalg::IndexOp::create(rewriter, nestedLoc, dim)); Value resultVal, predicate; if (isa(inputElementType)) { - resultVal = rewriter.create(nestedLoc, newValue, - oldValue); - predicate = rewriter.create( - nestedLoc, arith::CmpFPredicate::OGT, newValue, oldValue); + resultVal = arith::MaximumFOp::create(rewriter, nestedLoc, newValue, + oldValue); + predicate = arith::CmpFOp::create(rewriter, nestedLoc, + arith::CmpFPredicate::OGT, + newValue, oldValue); } else { arith::CmpIPredicate predType; predType = isUnsigned ? 
arith::CmpIPredicate::ugt : arith::CmpIPredicate::sgt; if (isUnsigned) { - resultVal = rewriter.create(nestedLoc, newValue, - oldValue); + resultVal = arith::MaxUIOp::create(rewriter, nestedLoc, newValue, + oldValue); } else { - resultVal = rewriter.create(nestedLoc, newValue, - oldValue); + resultVal = arith::MaxSIOp::create(rewriter, nestedLoc, newValue, + oldValue); } - predicate = rewriter.create(nestedLoc, predType, - newValue, oldValue); + predicate = arith::CmpIOp::create(rewriter, nestedLoc, predType, + newValue, oldValue); } - auto resultIndex = rewriter.create( - nestedLoc, predicate, newIndex, oldIndex); - nestedBuilder.create( - nestedLoc, ValueRange{resultVal, resultIndex}); + auto resultIndex = arith::SelectOp::create( + rewriter, nestedLoc, predicate, newIndex, oldIndex); + linalg::YieldOp::create(nestedBuilder, nestedLoc, + ValueRange{resultVal, resultIndex}); }); auto findMaxVal = findMaxLinalg.getResult(0); @@ -2301,28 +2304,29 @@ class ConvertAtenKthvalueOp : public OpConversionPattern { // Linalg generic op for indexing the topk output idx tensor using // the idx tensor returned by the linalg generic op for finding max. // Only the idx tensor from the linalg generic op is sent as input. - auto extractedIdxLinalg = rewriter.create( - loc, ArrayRef({filledTensorExtractedIdx.getType()}), findMaxIdx, - filledTensorExtractedIdx, extractedIdxMaps, extractedIdxIteratorTypes, + auto extractedIdxLinalg = linalg::GenericOp::create( + rewriter, loc, ArrayRef({filledTensorExtractedIdx.getType()}), + findMaxIdx, filledTensorExtractedIdx, extractedIdxMaps, + extractedIdxIteratorTypes, [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange blockArgs) { // Get the current input idx. - Value index = rewriter.create( - loc, rewriter.getIndexType(), blockArgs[0]); + Value index = arith::IndexCastOp::create( + rewriter, loc, rewriter.getIndexType(), blockArgs[0]); // Create idx to index the topk idx tensor. 
// Index the dim dimension using the current input idx. SmallVector indexTarget; for (unsigned i = 0; i < dim; i++) - indexTarget.push_back(rewriter.create(loc, i)); + indexTarget.push_back(linalg::IndexOp::create(rewriter, loc, i)); indexTarget.push_back(index); for (unsigned i = dim; i < findMaxIdxType.getRank(); i++) - indexTarget.push_back(rewriter.create(loc, i)); + indexTarget.push_back(linalg::IndexOp::create(rewriter, loc, i)); // Extract the element from the topk idx tensor. - Value extractedElement = rewriter.create( - loc, topkOpVal.back(), indexTarget); - rewriter.create(loc, extractedElement); + Value extractedElement = tensor::ExtractOp::create( + rewriter, loc, topkOpVal.back(), indexTarget); + linalg::YieldOp::create(rewriter, loc, extractedElement); }); auto extractedIdx = extractedIdxLinalg.getResult(0); @@ -2352,21 +2356,22 @@ class ConvertAtenKthvalueOp : public OpConversionPattern { // Linalg generic op for casting topk idx output tensor elements from i32 to // result idx tensor element type. - auto castedIdxLinalg = rewriter.create( - loc, ArrayRef({filledTensorCastedIdx.getType()}), extractedIdx, - filledTensorCastedIdx, castedIdxMaps, castedIdxIteratorTypes, + auto castedIdxLinalg = linalg::GenericOp::create( + rewriter, loc, ArrayRef({filledTensorCastedIdx.getType()}), + extractedIdx, filledTensorCastedIdx, castedIdxMaps, + castedIdxIteratorTypes, [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange blockArgs) { Value oldIdx = blockArgs[0]; // Cast from i32 to index. - Value oldIdxToIndexType = rewriter.create( - nestedLoc, rewriter.getIndexType(), oldIdx); + Value oldIdxToIndexType = arith::IndexCastOp::create( + rewriter, nestedLoc, rewriter.getIndexType(), oldIdx); // Cast from index to result idx element type. 
- Value resultIdx = rewriter.create( - nestedLoc, idxResultElementType, oldIdxToIndexType); + Value resultIdx = arith::IndexCastOp::create( + rewriter, nestedLoc, idxResultElementType, oldIdxToIndexType); - nestedBuilder.create(nestedLoc, resultIdx); + linalg::YieldOp::create(nestedBuilder, nestedLoc, resultIdx); }); auto castedIdx = castedIdxLinalg.getResult(0); @@ -2388,9 +2393,9 @@ class ConvertAtenKthvalueOp : public OpConversionPattern { if (!keepDim) { // If keepdim=false, cast the the outputs to appropriate type and return. Value retVal = - rewriter.create(loc, squeezedValType, findMaxVal); + tensor::CastOp::create(rewriter, loc, squeezedValType, findMaxVal); Value retIdx = - rewriter.create(loc, squeezedIdxType, castedIdx); + tensor::CastOp::create(rewriter, loc, squeezedIdxType, castedIdx); llvm::SmallVector res{retVal, retIdx}; rewriter.replaceOp(op, res); return success(); @@ -2409,10 +2414,11 @@ class ConvertAtenKthvalueOp : public OpConversionPattern { valShape.resize(valShape.size() - 1); idxShape.resize(idxShape.size() - 1); - Value retVal = rewriter.create( - loc, squeezedValType.clone(valShape), findMaxLinalg.getResult(0)); - Value retIdx = rewriter.create( - loc, squeezedIdxType.clone(idxShape), castedIdx); + Value retVal = + tensor::CastOp::create(rewriter, loc, squeezedValType.clone(valShape), + findMaxLinalg.getResult(0)); + Value retIdx = tensor::CastOp::create( + rewriter, loc, squeezedIdxType.clone(idxShape), castedIdx); SmallVector reassociation(valShape.size()); if (reassociation.size() > 0) { @@ -2433,11 +2439,11 @@ class ConvertAtenKthvalueOp : public OpConversionPattern { valShape[dim] = 1; idxShape[dim] = 1; - Value unsqueezeVal = rewriter.create( - loc, valResultType, retVal, reassociation); + Value unsqueezeVal = tensor::ExpandShapeOp::create( + rewriter, loc, valResultType, retVal, reassociation); - Value unsqueezeIdx = rewriter.create( - loc, idxResultType, retIdx, reassociation); + Value unsqueezeIdx = 
tensor::ExpandShapeOp::create( + rewriter, loc, idxResultType, retIdx, reassociation); // Return unsqueezed. llvm::SmallVector unsqueezes = {unsqueezeVal, unsqueezeIdx}; diff --git a/lib/Conversion/TorchToTensor/TorchToTensor.cpp b/lib/Conversion/TorchToTensor/TorchToTensor.cpp index 76b9b87cbfe9..10fd2a160d0d 100644 --- a/lib/Conversion/TorchToTensor/TorchToTensor.cpp +++ b/lib/Conversion/TorchToTensor/TorchToTensor.cpp @@ -38,26 +38,26 @@ class ConvertAtenItemOp : public OpConversionPattern { if (operandTy.getNumElements() != 1) return rewriter.notifyMatchFailure(op, "expected only one item"); - auto zeroIdx = rewriter.create(op.getLoc(), 0); + auto zeroIdx = arith::ConstantIndexOp::create(rewriter, op.getLoc(), 0); auto rank = operandTy.getRank(); llvm::SmallVector indices(rank, zeroIdx); - Value extract = rewriter.create( - op.getLoc(), operandTy.getElementType(), operand, indices); + Value extract = tensor::ExtractOp::create( + rewriter, op.getLoc(), operandTy.getElementType(), operand, indices); auto extractTy = extract.getType(); if (isa(extractTy) && !extractTy.isInteger(64)) { if (torchDTy.isUnsignedInteger()) { - extract = rewriter.create( - op.getLoc(), rewriter.getIntegerType(64), extract); + extract = arith::ExtUIOp::create(rewriter, op.getLoc(), + rewriter.getIntegerType(64), extract); } else { - extract = rewriter.create( - op.getLoc(), rewriter.getIntegerType(64), extract); + extract = arith::ExtSIOp::create(rewriter, op.getLoc(), + rewriter.getIntegerType(64), extract); } } if (isa(extractTy) && !extractTy.isF64()) { - extract = rewriter.create(op.getLoc(), - rewriter.getF64Type(), extract); + extract = arith::ExtFOp::create(rewriter, op.getLoc(), + rewriter.getF64Type(), extract); } rewriter.replaceOp(op, extract); @@ -124,10 +124,10 @@ class ConvertAtenTensorOpPattern : public OpConversionPattern { operand); if (isa(resultETy) && value.getType() != resultETy) - value = rewriter.create(loc, resultETy, value); + value = 
arith::TruncIOp::create(rewriter, loc, resultETy, value); if (isa(resultETy) && value.getType() != resultETy) - value = rewriter.create(loc, resultETy, value); + value = arith::TruncFOp::create(rewriter, loc, resultETy, value); values.push_back(value); } diff --git a/lib/Conversion/TorchToTosa/TorchToTosa.cpp b/lib/Conversion/TorchToTosa/TorchToTosa.cpp index 52a8ce73cee1..c959f06c6a66 100644 --- a/lib/Conversion/TorchToTosa/TorchToTosa.cpp +++ b/lib/Conversion/TorchToTosa/TorchToTosa.cpp @@ -128,16 +128,16 @@ class ConvertAtenBinaryOp : public OpConversionPattern { if constexpr (std::is_same()) { // TOSA ArithmeticRightShiftOp has a round parameter. - binaryOp = rewriter.create(op->getLoc(), outTy, lhs, rhs, - /*round=*/false); + binaryOp = TosaOpT::create(rewriter, op->getLoc(), outTy, lhs, rhs, + /*round=*/false); } else if constexpr (std::is_same() || std::is_same()) { lhs = tosa::tosaCastTensorToType(rewriter, lhs, outTy).value(); rhs = tosa::tosaCastTensorToType(rewriter, rhs, outTy).value(); // Use default NaN Propagation mode "PROPAGATE" for tosa.maximum and // tosa.minimum - binaryOp = rewriter.create( - op->getLoc(), outTy, lhs, rhs, + binaryOp = TosaOpT::create( + rewriter, op->getLoc(), outTy, lhs, rhs, /*nan_mode=*/ tosa::NanPropagationModeAttr::get( rewriter.getContext(), tosa::NanPropagationMode::PROPAGATE)); @@ -500,9 +500,9 @@ class ConvertAtenCompareOp : public OpConversionPattern { } } - auto resultOp = rewriter.create(op.getLoc(), resultTy, - (swapLhsRhs ? rhsTensor : lhs), - (swapLhsRhs ? lhs : rhsTensor)); + auto resultOp = TosaOpT::create(rewriter, op.getLoc(), resultTy, + (swapLhsRhs ? rhsTensor : lhs), + (swapLhsRhs ? lhs : rhsTensor)); // There is no NE operator in TOSA. 
if constexpr (std::is_same() || @@ -615,19 +615,19 @@ std::optional truncFloatDivWithDivResult(PatternRewriter &rewriter, .failed()) return std::nullopt; - auto cond = rewriter.create( - op->getLoc(), + auto cond = tosa::GreaterEqualOp::create( + rewriter, op->getLoc(), RankedTensorType::get(outType.getShape(), rewriter.getIntegerType(1)), divResult, zero); - auto selectOp = rewriter.create(op->getLoc(), outType, cond, - one, minusOne); + auto selectOp = tosa::SelectOp::create(rewriter, op->getLoc(), outType, cond, + one, minusOne); auto absDivResult = - rewriter.create(op->getLoc(), outType, divResult); + tosa::AbsOp::create(rewriter, op->getLoc(), outType, divResult); auto flooredAbsDivResult = - rewriter.create(op->getLoc(), outType, absDivResult); + tosa::FloorOp::create(rewriter, op->getLoc(), outType, absDivResult); Value result = tosa::createMulOpAndCast(rewriter, op, outType, flooredAbsDivResult, @@ -644,7 +644,7 @@ Value truncFloatDiv(PatternRewriter &rewriter, Operation *op, rhs = tosa::tosaCastTensorToType(rewriter, rhs, outType).value(); auto rhsRcp = - rewriter.create(op->getLoc(), rhs.getType(), rhs); + tosa::ReciprocalOp::create(rewriter, op->getLoc(), rhs.getType(), rhs); auto divResult = tosa::createMulOpAndCast(rewriter, op, outType, lhs, rhsRcp, /*shift=*/0); @@ -671,7 +671,7 @@ std::optional floorIntDiv(PatternRewriter &rewriter, Operation *op, rhs = tosa::tosaCastTensorToType(rewriter, rhs, i32Type).value(); auto intDivOp = - rewriter.create(op->getLoc(), i32Type, lhs, rhs); + tosa::IntDivOp::create(rewriter, op->getLoc(), i32Type, lhs, rhs); auto zero = tosa::getConstTensor(rewriter, op, 0, {}).value(); @@ -687,26 +687,27 @@ std::optional floorIntDiv(PatternRewriter &rewriter, Operation *op, auto lhsMulRhs = tosa::createMulOpAndCast(rewriter, op, i32Type, lhs, rhs, /*shift=*/0); - auto lhsRhsDifferentSign = - rewriter.create(op->getLoc(), boolType, zero, lhsMulRhs); + auto lhsRhsDifferentSign = tosa::GreaterOp::create(rewriter, op->getLoc(), + 
boolType, zero, lhsMulRhs); auto truncMulRhs = tosa::createMulOpAndCast(rewriter, op, i32Type, intDivOp, rhs, /*shift=*/0); auto truncMulRhsEqualLhs = - rewriter.create(op->getLoc(), boolType, truncMulRhs, lhs); + tosa::EqualOp::create(rewriter, op->getLoc(), boolType, truncMulRhs, lhs); - auto truncMulRhsNotEqualLhs = rewriter.create( - op->getLoc(), boolType, truncMulRhsEqualLhs); + auto truncMulRhsNotEqualLhs = tosa::LogicalNotOp::create( + rewriter, op->getLoc(), boolType, truncMulRhsEqualLhs); auto truncMinusOne = - rewriter.create(op->getLoc(), i32Type, intDivOp, one); + tosa::SubOp::create(rewriter, op->getLoc(), i32Type, intDivOp, one); - auto cond = rewriter.create( - op->getLoc(), boolType, lhsRhsDifferentSign, truncMulRhsNotEqualLhs); + auto cond = + tosa::LogicalAndOp::create(rewriter, op->getLoc(), boolType, + lhsRhsDifferentSign, truncMulRhsNotEqualLhs); - auto selectOp = rewriter.create(op->getLoc(), i32Type, cond, - truncMinusOne, intDivOp); + auto selectOp = tosa::SelectOp::create(rewriter, op->getLoc(), i32Type, cond, + truncMinusOne, intDivOp); Value result = tosa::tosaCastTensorToType(rewriter, selectOp, outType).value(); @@ -770,8 +771,8 @@ class ConvertAtenDivOp : public OpConversionPattern { // types can only be floating point for tosa::ReciprocalOp. rhsTensor = tosa::tosaCastTensorToType(rewriter, rhsTensor, outType).value(); - auto rhsRcp = rewriter.create( - op->getLoc(), rhsTensor.getType(), rhsTensor); + auto rhsRcp = tosa::ReciprocalOp::create(rewriter, op->getLoc(), + rhsTensor.getType(), rhsTensor); auto divResult = tosa::createMulOpAndCast(rewriter, op, outType, lhs, rhsRcp, /*shift=*/0); @@ -781,7 +782,7 @@ class ConvertAtenDivOp : public OpConversionPattern { // "floor": rounds the results of the division down. Equivalent to // floor division in Python (the // operator). 
auto floorOp = - rewriter.create(op->getLoc(), outType, divResult); + tosa::FloorOp::create(rewriter, op->getLoc(), outType, divResult); result = floorOp.getResult(); } else if (roundMode.compare("trunc") == 0) { @@ -810,8 +811,8 @@ class ConvertAtenDivOp : public OpConversionPattern { rhsTensor = tosa::tosaCastTensorToType(rewriter, rhsTensor, i32Type).value(); - auto intDivOp = rewriter.create(op->getLoc(), i32Type, - lhs, rhsTensor); + auto intDivOp = tosa::IntDivOp::create(rewriter, op->getLoc(), i32Type, + lhs, rhsTensor); result = tosa::tosaCastTensorToType(rewriter, intDivOp, outType).value(); @@ -946,8 +947,8 @@ LogicalResult ConvertAtenOp::matchAndRewrite( return rewriter.notifyMatchFailure( op, "Failed to equalize ranks among operands and result"); - auto cond = rewriter.create( - op->getLoc(), + auto cond = tosa::GreaterEqualOp::create( + rewriter, op->getLoc(), RankedTensorType::get(selfTy.getShape(), rewriter.getIntegerType(1)), self, zero); @@ -1237,13 +1238,13 @@ LogicalResult ConvertAtenOp::matchAndRewrite( rewriter.getIntegerAttr(rewriter.getI64Type(), reduceDim); // Use default NaN Propagation mode "PROPAGATE" for tosa.argmax - return rewriter - .create( - op->getLoc(), getTypeConverter()->convertType(outputReduceTy), - input, reduceDimAttr, - /*nan_mode=*/ - tosa::NanPropagationModeAttr::get( - rewriter.getContext(), tosa::NanPropagationMode::PROPAGATE)) + return tosa::ArgMaxOp::create( + rewriter, op->getLoc(), + getTypeConverter()->convertType(outputReduceTy), input, + reduceDimAttr, + /*nan_mode=*/ + tosa::NanPropagationModeAttr::get( + rewriter.getContext(), tosa::NanPropagationMode::PROPAGATE)) .getResult(); }; @@ -1319,8 +1320,8 @@ class ConvertAtenSqueezeOp : public OpConversionPattern { auto newOutputTy = RankedTensorType::get( makeShapeLLVMCompatible(newOutputShape), resultElemTy); - auto reshapeOp = rewriter.create( - op->getLoc(), + auto reshapeOp = tosa::ReshapeOp::create( + rewriter, op->getLoc(), 
OpConversionPattern::getTypeConverter()->convertType( newOutputTy), self, tosa::getTosaConstShape(rewriter, op->getLoc(), newOutputShape)); @@ -1559,8 +1560,8 @@ class ConvertAtenMatmulBaseOp : public OpConversionPattern { auto rankBroadcastedLhs = lhsRank == maxInputRank ? lhs - : rewriter.create( - op->getLoc(), + : tosa::ReshapeOp::create( + rewriter, op->getLoc(), OpConversionPattern::getTypeConverter()->convertType( lhsBroadcastedTy), lhs, @@ -1570,8 +1571,8 @@ class ConvertAtenMatmulBaseOp : public OpConversionPattern { auto rankBroadcastedRhs = rhsRank == maxInputRank ? rhs - : rewriter.create( - op->getLoc(), + : tosa::ReshapeOp::create( + rewriter, op->getLoc(), OpConversionPattern::getTypeConverter()->convertType( rhsBroadcastedTy), rhs, @@ -1600,8 +1601,8 @@ class ConvertAtenMatmulBaseOp : public OpConversionPattern { auto newType = RankedTensorType::get(makeShapeLLVMCompatible(newShape), tensorTy.getElementType()); - return rewriter.create( - op->getLoc(), + return tosa::ReshapeOp::create( + rewriter, op->getLoc(), OpConversionPattern::getTypeConverter()->convertType( newType), tensor, tosa::getTosaConstShape(rewriter, op->getLoc(), newShape)); @@ -1766,13 +1767,12 @@ class ConvertAtenMatmulBaseOp : public OpConversionPattern { makeShapeLLVMCompatible(transposedLhsShape), rhsElemTy); lhsReshapeInput = - rewriter - .create( - op->getLoc(), - OpConversionPattern::getTypeConverter() - ->convertType(transposedLhsType), - rankBroadcastedLhs, - rewriter.getDenseI32ArrayAttr(transposedLhsDims)) + tosa::TransposeOp::create( + rewriter, op->getLoc(), + OpConversionPattern::getTypeConverter()->convertType( + transposedLhsType), + rankBroadcastedLhs, + rewriter.getDenseI32ArrayAttr(transposedLhsDims)) .getResult(); } @@ -1785,8 +1785,8 @@ class ConvertAtenMatmulBaseOp : public OpConversionPattern { auto newLhsType = RankedTensorType::get( makeShapeLLVMCompatible(newLhsShape), lhsElemTy); - matmulLhs = rewriter.create( - op->getLoc(), + matmulLhs = 
tosa::ReshapeOp::create( + rewriter, op->getLoc(), OpConversionPattern::getTypeConverter()->convertType( newLhsType), lhsReshapeInput, @@ -1842,18 +1842,17 @@ class ConvertAtenMatmulBaseOp : public OpConversionPattern { if (rhsNeedsTranspose) transposedRhsValue = - rewriter - .create( - op->getLoc(), - OpConversionPattern::getTypeConverter() - ->convertType(transposedRhsType), - rankBroadcastedRhs, - rewriter.getDenseI32ArrayAttr(transposedRhsDims)) + tosa::TransposeOp::create( + rewriter, op->getLoc(), + OpConversionPattern::getTypeConverter()->convertType( + transposedRhsType), + rankBroadcastedRhs, + rewriter.getDenseI32ArrayAttr(transposedRhsDims)) .getResult(); // reshape - matmulRhs = rewriter.create( - op->getLoc(), + matmulRhs = tosa::ReshapeOp::create( + rewriter, op->getLoc(), OpConversionPattern::getTypeConverter()->convertType( newRhsType), transposedRhsValue, @@ -1897,21 +1896,19 @@ class ConvertAtenMatmulBaseOp : public OpConversionPattern { tosa::createZeroPointTensor(rewriter, op->getLoc(), rhsElemTy, 0) .value(); mmOpResult = - rewriter - .create( - op->getLoc(), - OpConversionPattern::getTypeConverter()->convertType( - mmOutputTy), - matmulLhs, matmulRhs, lhsZp, rhsZp) + tosa::MatMulOp::create( + rewriter, op->getLoc(), + OpConversionPattern::getTypeConverter()->convertType( + mmOutputTy), + matmulLhs, matmulRhs, lhsZp, rhsZp) .getResult(); } else { mmOpResult = - rewriter - .create( - op->getLoc(), - OpConversionPattern::getTypeConverter()->convertType( - mmOutputTy), - matmulLhs, matmulRhs) + tosa::MatMulOp::create( + rewriter, op->getLoc(), + OpConversionPattern::getTypeConverter()->convertType( + mmOutputTy), + matmulLhs, matmulRhs) .getResult(); } @@ -2016,8 +2013,8 @@ class ConvertAtenMatmulBaseOp : public OpConversionPattern { // Perform reshape auto reshapedOpType = RankedTensorType::get( makeShapeLLVMCompatible(reshapedOpShape), accElemTy); - auto reshapedOp = rewriter.create( - op->getLoc(), + auto reshapedOp = tosa::ReshapeOp::create( + 
rewriter, op->getLoc(), OpConversionPattern::getTypeConverter()->convertType( reshapedOpType), mmOpResult, @@ -2026,14 +2023,14 @@ class ConvertAtenMatmulBaseOp : public OpConversionPattern { if (opNeedsTranspose) { auto transposedOpType = RankedTensorType::get( makeShapeLLVMCompatible(transposedOpShape), accElemTy); - output = rewriter - .create( - op->getLoc(), - OpConversionPattern::getTypeConverter() - ->convertType(transposedOpType), - reshapedOp.getResult(), - rewriter.getDenseI32ArrayAttr(transposedOpDims)) - .getResult(); + output = + tosa::TransposeOp::create( + rewriter, op->getLoc(), + OpConversionPattern::getTypeConverter()->convertType( + transposedOpType), + reshapedOp.getResult(), + rewriter.getDenseI32ArrayAttr(transposedOpDims)) + .getResult(); } else { output = reshapedOp.getResult(); @@ -2211,8 +2208,8 @@ class ConvertAtenLinearOp : public ConvertAtenMatmulBaseOp { auto transposedRhsType = RankedTensorType::get( makeShapeLLVMCompatible(transposedRhsShape), rhsElemTy); - rhs = rewriter.create( - op->getLoc(), + rhs = tosa::TransposeOp::create( + rewriter, op->getLoc(), OpConversionPattern::getTypeConverter()->convertType( transposedRhsType), rhs, rewriter.getDenseI32ArrayAttr(transposedRhsDims)); @@ -2240,9 +2237,8 @@ class ConvertAtenLinearOp : public ConvertAtenMatmulBaseOp { op, "Failed to equalize ranks among operands and result"); matmulPlusBias = - rewriter - .create(op->getLoc(), matmulPlusBias.getType(), - matmulPlusBias, bias) + tosa::AddOp::create(rewriter, op->getLoc(), matmulPlusBias.getType(), + matmulPlusBias, bias) .getResult(); } @@ -2402,11 +2398,10 @@ LogicalResult ConvertAtenOp::matchAndRewrite( auto transposedInputType = RankedTensorType::get( makeShapeLLVMCompatible(transposedInputShape), inputElemTy); auto transposedInput = - rewriter - .create( - op->getLoc(), - getTypeConverter()->convertType(transposedInputType), input, - rewriter.getDenseI32ArrayAttr(nchwToNhwcDims)) + tosa::TransposeOp::create( + rewriter, op->getLoc(), 
+ getTypeConverter()->convertType(transposedInputType), input, + rewriter.getDenseI32ArrayAttr(nchwToNhwcDims)) .getResult(); SmallVector transformedWeightShape; @@ -2420,11 +2415,10 @@ LogicalResult ConvertAtenOp::matchAndRewrite( transformedWeightType = RankedTensorType::get( makeShapeLLVMCompatible(transformedWeightShape), weightElemTy); transformedWeight = - rewriter - .create( - op->getLoc(), - getTypeConverter()->convertType(transformedWeightType), weight, - rewriter.getDenseI32ArrayAttr(nchwToNhwcDims)) + tosa::TransposeOp::create( + rewriter, op->getLoc(), + getTypeConverter()->convertType(transformedWeightType), weight, + rewriter.getDenseI32ArrayAttr(nchwToNhwcDims)) .getResult(); outputCDim = transformedWeightShape[0]; } else if (weightShape[1] == 1) { @@ -2436,11 +2430,10 @@ LogicalResult ConvertAtenOp::matchAndRewrite( auto transposedWeightType = RankedTensorType::get( makeShapeLLVMCompatible(transposedWeightShape), weightElemTy); auto transposedWeight = - rewriter - .create( - op->getLoc(), - getTypeConverter()->convertType(transposedWeightType), weight, - rewriter.getDenseI32ArrayAttr(transposedDims)) + tosa::TransposeOp::create( + rewriter, op->getLoc(), + getTypeConverter()->convertType(transposedWeightType), weight, + rewriter.getDenseI32ArrayAttr(transposedDims)) .getResult(); // reshape: HWO(I/G) -> HWIM @@ -2459,13 +2452,12 @@ LogicalResult ConvertAtenOp::matchAndRewrite( transformedWeightType = RankedTensorType::get( makeShapeLLVMCompatible(transformedWeightShape), weightElemTy); transformedWeight = - rewriter - .create( - op->getLoc(), - getTypeConverter()->convertType(transformedWeightType), - transposedWeight, - tosa::getTosaConstShape(rewriter, op->getLoc(), - transformedWeightShape)) + tosa::ReshapeOp::create( + rewriter, op->getLoc(), + getTypeConverter()->convertType(transformedWeightType), + transposedWeight, + tosa::getTosaConstShape(rewriter, op->getLoc(), + transformedWeightShape)) .getResult(); } else { llvm_unreachable("Unhandled 
convolution type"); @@ -2566,24 +2558,22 @@ LogicalResult ConvertAtenOp::matchAndRewrite( if (groups == 1) { // full convolution convOpResult = - rewriter - .create( - op->getLoc(), getTypeConverter()->convertType(convOpTy), - transposedInput, transformedWeight, bias, inputZp, weightZp, - rewriter.getDenseI64ArrayAttr(padding), - rewriter.getDenseI64ArrayAttr(stride), - rewriter.getDenseI64ArrayAttr(dilation), accType) + tosa::Conv2DOp::create( + rewriter, op->getLoc(), getTypeConverter()->convertType(convOpTy), + transposedInput, transformedWeight, bias, inputZp, weightZp, + rewriter.getDenseI64ArrayAttr(padding), + rewriter.getDenseI64ArrayAttr(stride), + rewriter.getDenseI64ArrayAttr(dilation), accType) .getResult(); } else if (weightShape[1] == 1) { // depthwise convolution convOpResult = - rewriter - .create( - op->getLoc(), getTypeConverter()->convertType(convOpTy), - transposedInput, transformedWeight, bias, inputZp, weightZp, - rewriter.getDenseI64ArrayAttr(padding), - rewriter.getDenseI64ArrayAttr(stride), - rewriter.getDenseI64ArrayAttr(dilation), accType) + tosa::DepthwiseConv2DOp::create( + rewriter, op->getLoc(), getTypeConverter()->convertType(convOpTy), + transposedInput, transformedWeight, bias, inputZp, weightZp, + rewriter.getDenseI64ArrayAttr(padding), + rewriter.getDenseI64ArrayAttr(stride), + rewriter.getDenseI64ArrayAttr(dilation), accType) .getResult(); } else { llvm_unreachable("Unhandled convolution type"); @@ -2595,11 +2585,10 @@ LogicalResult ConvertAtenOp::matchAndRewrite( auto transposedOutputType = RankedTensorType::get( makeShapeLLVMCompatible(transposedOutputShape), biasElemTy); auto transposedOutput = - rewriter - .create( - op->getLoc(), - getTypeConverter()->convertType(transposedOutputType), - convOpResult, rewriter.getDenseI32ArrayAttr(nhwcToNchwDims)) + tosa::TransposeOp::create( + rewriter, op->getLoc(), + getTypeConverter()->convertType(transposedOutputType), convOpResult, + rewriter.getDenseI32ArrayAttr(nhwcToNchwDims)) 
.getResult(); Value rescaledResult = transposedOutput; @@ -2692,13 +2681,13 @@ std::optional computeBatchNorm(Operation *op, return std::nullopt; auto op1SubInputMean = - rewriter.create(op->getLoc(), outType, input, mean); + tosa::SubOp::create(rewriter, op->getLoc(), outType, input, mean); - auto op2AddVarEpsilon = rewriter.create( - op->getLoc(), variance.getType(), variance, eps); + auto op2AddVarEpsilon = tosa::AddOp::create( + rewriter, op->getLoc(), variance.getType(), variance, eps); - auto op3RsqrtOp2 = rewriter.create( - op->getLoc(), variance.getType(), op2AddVarEpsilon.getResult()); + auto op3RsqrtOp2 = tosa::RsqrtOp::create( + rewriter, op->getLoc(), variance.getType(), op2AddVarEpsilon.getResult()); auto op4MulOp1Op3 = tosa::createMulOpAndCast( rewriter, op, dyn_cast(outType), op1SubInputMean.getResult(), @@ -2708,9 +2697,8 @@ std::optional computeBatchNorm(Operation *op, tosa::createMulOpAndCast(rewriter, op, dyn_cast(outType), op4MulOp1Op3.getResult(), weight, 0); - return rewriter - .create(op->getLoc(), outType, op5MulOp4Scale.getResult(), - bias) + return tosa::AddOp::create(rewriter, op->getLoc(), outType, + op5MulOp4Scale.getResult(), bias) .getResult(); } @@ -2759,8 +2747,8 @@ LogicalResult ConvertAtenOp::matchAndRewrite( auto newType = RankedTensorType::get(makeShapeLLVMCompatible(newShape), outTensorType.getElementType()); - result = rewriter.create( - op->getLoc(), newType, toBcast, + result = tosa::ReshapeOp::create( + rewriter, op->getLoc(), newType, toBcast, tosa::getTosaConstShape(rewriter, op->getLoc(), newShape)); return success(); @@ -2887,15 +2875,15 @@ LogicalResult ConvertAtenOp::matchAndRewrite( makeShapeTorchCompatible(toReduceType.getShape())); for (int64_t i = toReduceShape.size() - 1; i >= meanAndVarShapeRank; i--) { toReduceShape[i] = 1; - sumDiv = rewriter.create( - op.getLoc(), + sumDiv = tosa::ReduceSumOp::create( + rewriter, op.getLoc(), RankedTensorType::get(makeShapeLLVMCompatible(toReduceShape), 
inputType.getElementType()), sumDiv, rewriter.getI32IntegerAttr(i)); } - return rewriter.create( - op.getLoc(), outType, sumDiv, + return tosa::ReshapeOp::create( + rewriter, op.getLoc(), outType, sumDiv, tosa::getTosaConstShape(rewriter, op->getLoc(), outShape)); }; @@ -2909,8 +2897,8 @@ LogicalResult ConvertAtenOp::matchAndRewrite( tosa::getConstTensor(rewriter, op.getOperation(), {static_cast(elemCnt)}, {1}, elemTy) .value(); - Value elemCntRcp = rewriter.create( - op.getLoc(), elemCntConst.getType(), elemCntConst); + Value elemCntRcp = tosa::ReciprocalOp::create( + rewriter, op.getLoc(), elemCntConst.getType(), elemCntConst); if (mlir::tosa::EqualizeRanks(rewriter, op->getLoc(), input, elemCntRcp) .failed()) @@ -2935,7 +2923,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( // Compute variance. Value squareSumSub = - rewriter.create(op.getLoc(), inputType, input, meanVal); + tosa::SubOp::create(rewriter, op.getLoc(), inputType, input, meanVal); Value squareSum = tosa::createMulOpAndCast(rewriter, op, inputType, squareSumSub, squareSumSub, 0); @@ -2955,12 +2943,12 @@ LogicalResult ConvertAtenOp::matchAndRewrite( auto weightAndMeanBcastType = RankedTensorType::get( makeShapeLLVMCompatible(weightAndBiasBcastShape), elemTy); - Value weightVal = rewriter.create( - op.getLoc(), weightAndMeanBcastType, weight, + Value weightVal = tosa::ReshapeOp::create( + rewriter, op.getLoc(), weightAndMeanBcastType, weight, tosa::getTosaConstShape(rewriter, op->getLoc(), weightAndBiasBcastShape)); - Value biasVal = rewriter.create( - op.getLoc(), weightAndMeanBcastType, bias, + Value biasVal = tosa::ReshapeOp::create( + rewriter, op.getLoc(), weightAndMeanBcastType, bias, tosa::getTosaConstShape(rewriter, op->getLoc(), weightAndBiasBcastShape)); double eps; @@ -3065,8 +3053,8 @@ LogicalResult ConvertAtenOp::matchAndRewrite( auto newType = RankedTensorType::get(makeShapeLLVMCompatible(newShape), selfType.getElementType()); - auto reshapeOp = rewriter.create( - op.getLoc(), newType, 
adaptor.getSelf(), + auto reshapeOp = tosa::ReshapeOp::create( + rewriter, op.getLoc(), newType, adaptor.getSelf(), tosa::getTosaConstShape(rewriter, op->getLoc(), newShape)); rewriter.replaceOpWithNewOp( @@ -3196,9 +3184,9 @@ LogicalResult ConvertAtenOp::matchAndRewrite( op, "Failed to equalize ranks among operands and result"); auto rcpOp = - rewriter.create(op.getLoc(), ln2Op.getType(), ln2Op); + tosa::ReciprocalOp::create(rewriter, op.getLoc(), ln2Op.getType(), ln2Op); - auto logOp = rewriter.create(op.getLoc(), outType, self); + auto logOp = tosa::LogOp::create(rewriter, op.getLoc(), outType, self); auto result = tosa::createMulOpAndCast(rewriter, op, outType, logOp, rcpOp, /*shift=*/0); @@ -3246,8 +3234,8 @@ LogicalResult ConvertAtenOp::matchAndRewrite( return rewriter.notifyMatchFailure( op, "Failed to equalize ranks among operands and result"); - auto cmpOp = rewriter.create( - op.getLoc(), + auto cmpOp = tosa::GreaterOp::create( + rewriter, op.getLoc(), RankedTensorType::get(selfType.getShape(), rewriter.getIntegerType(1)), self, threshold); @@ -3433,13 +3421,13 @@ buildUnitNormalCdf(ConversionPatternRewriter &rewriter, Operation *op, Value x, // buildNormalCdf, mean = zero, sigma = one auto outType = dyn_cast(x.getType()); auto mean = zero; - Value xMinusMean = rewriter.create(loc, outType, x, mean); + Value xMinusMean = tosa::SubOp::create(rewriter, loc, outType, x, mean); Value erfArg = tosa::createMulOpAndCast(rewriter, op, outType, xMinusMean, rsqrt2, /*shift=*/0); - Value erf = rewriter.create(loc, outType, erfArg); - Value erfPlus1 = rewriter.create(loc, outType, one, erf); + Value erf = tosa::ErfOp::create(rewriter, loc, outType, erfArg); + Value erfPlus1 = tosa::AddOp::create(rewriter, loc, outType, one, erf); Value normalCdf = tosa::createMulOpAndCast(rewriter, op, outType, oneHalf, erfPlus1, /*shift=*/0); @@ -3531,12 +3519,12 @@ LogicalResult ConvertAtenOp::matchAndRewrite( self, /*shift=*/0); // sqrt(2/pi) - auto sqrtTwoOverPi = - 
rewriter.create(op->getLoc(), resultType, twoOverPi, half); + auto sqrtTwoOverPi = tosa::PowOp::create(rewriter, op->getLoc(), resultType, + twoOverPi, half); // x^3 auto inputPowThree = - rewriter.create(op->getLoc(), resultType, self, three); + tosa::PowOp::create(rewriter, op->getLoc(), resultType, self, three); // 0.044715 * x^3 auto inputPowThreeMul = @@ -3544,8 +3532,8 @@ LogicalResult ConvertAtenOp::matchAndRewrite( inputPowThree.getResult(), /*shift=*/0); // x + 0.044715 * x^3 - auto inputPowThreeMulAdd = rewriter.create( - op->getLoc(), resultType, self, inputPowThreeMul.getResult()); + auto inputPowThreeMulAdd = tosa::AddOp::create( + rewriter, op->getLoc(), resultType, self, inputPowThreeMul.getResult()); // sqrt(2/pi) * (x + 0.044715 * x^3) auto sqrtTwoOverPiMul = tosa::createMulOpAndCast( @@ -3553,12 +3541,12 @@ LogicalResult ConvertAtenOp::matchAndRewrite( inputPowThreeMulAdd.getResult(), /*shift=*/0); // tanh(sqrt(2/pi) * (x + 0.044715 * x^3)) - auto tanh = rewriter.create(op->getLoc(), resultType, - sqrtTwoOverPiMul.getResult()); + auto tanh = tosa::TanhOp::create(rewriter, op->getLoc(), resultType, + sqrtTwoOverPiMul.getResult()); // 1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)) - auto tanhAdd = rewriter.create(op->getLoc(), resultType, one, - tanh.getResult()); + auto tanhAdd = tosa::AddOp::create(rewriter, op->getLoc(), resultType, one, + tanh.getResult()); auto result = tosa::createMulOpAndCast(rewriter, op, resultType, halfInput.getResult(), @@ -3622,14 +3610,14 @@ LogicalResult ConvertAtenOp::matchAndRewrite( Value negHalfInputSquared = tosa::createMulOpAndCast( rewriter, op, selfType, inputSquared, negOneHalf, /*shift=*/0); Value dinput = - rewriter.create(loc, selfType, negHalfInputSquared); + tosa::ExpOp::create(rewriter, loc, selfType, negHalfInputSquared); Value cdf = buildUnitNormalCdf(rewriter, op, self, selfElemTy).value(); Value dinputInput = tosa::createMulOpAndCast(rewriter, op, selfType, dinput, self, /*shift=*/0); Value 
dinputInputAlpha = tosa::createMulOpAndCast( rewriter, op, selfType, dinputInput, kAlphaHalf, /*shift=*/0); Value cdfExt = - rewriter.create(loc, selfType, dinputInputAlpha, cdf); + tosa::AddOp::create(rewriter, loc, selfType, dinputInputAlpha, cdf); auto resultTy = dyn_cast(getTypeConverter()->convertType(op.getType())); @@ -3706,18 +3694,18 @@ LogicalResult ConvertAtenOp::matchAndRewrite( Type outType = getTypeConverter()->convertType(op.getType()); - Value lesser = rewriter.create( - op.getLoc(), + Value lesser = tosa::GreaterOp::create( + rewriter, op.getLoc(), RankedTensorType::get(selfType.getShape(), rewriter.getIntegerType(1)), minVal, self); - Value greater = rewriter.create( - op.getLoc(), + Value greater = tosa::GreaterOp::create( + rewriter, op.getLoc(), RankedTensorType::get(selfType.getShape(), rewriter.getIntegerType(1)), self, maxVal); - Value cmp = rewriter.create( - op.getLoc(), + Value cmp = tosa::LogicalOrOp::create( + rewriter, op.getLoc(), RankedTensorType::get(selfType.getShape(), rewriter.getIntegerType(1)), lesser, greater); @@ -3785,8 +3773,8 @@ LogicalResult ConvertAtenOp::matchAndRewrite( for (auto s : weightShape) newWeightShape.push_back(s); - auto reshapedWeight = rewriter.create( - op->getLoc(), + auto reshapedWeight = tosa::ReshapeOp::create( + rewriter, op->getLoc(), RankedTensorType::get(makeShapeLLVMCompatible(newWeightShape), weightType.getElementType()), weight, tosa::getTosaConstShape(rewriter, op->getLoc(), newWeightShape)); @@ -3800,8 +3788,8 @@ LogicalResult ConvertAtenOp::matchAndRewrite( } SmallVector newIndicesShape = {1, numIndices}; - auto reshapedIndices = rewriter.create( - op->getLoc(), + auto reshapedIndices = tosa::ReshapeOp::create( + rewriter, op->getLoc(), RankedTensorType::get(makeShapeLLVMCompatible(newIndicesShape), indicesType.getElementType()), indices, @@ -3815,8 +3803,8 @@ LogicalResult ConvertAtenOp::matchAndRewrite( .value(); SmallVector intermediateOutShape = {1, numIndices, weightShape[1]}; - auto 
gatherOp = rewriter.create( - op->getLoc(), + auto gatherOp = tosa::GatherOp::create( + rewriter, op->getLoc(), RankedTensorType::get(makeShapeLLVMCompatible(intermediateOutShape), weightType.getElementType()), reshapedWeight, castIndices); @@ -3927,16 +3915,16 @@ class ConvertAtenMinMaxDimOp : public OpConversionPattern { std::is_same()) { // Use default NaN Propagation mode "PROPAGATE" for tosa.reduce_min // and tosa.reduce_max - reduceOp = rewriter.create( - op->getLoc(), + reduceOp = TosaOpT::create( + rewriter, op->getLoc(), RankedTensorType::get(makeShapeLLVMCompatible(reducedShape), selfElemType), self, dimAttr, /*nan_mode=*/ tosa::NanPropagationModeAttr::get( rewriter.getContext(), tosa::NanPropagationMode::PROPAGATE)); } else { - reduceOp = rewriter.create( - op->getLoc(), + reduceOp = TosaOpT::create( + rewriter, op->getLoc(), RankedTensorType::get(makeShapeLLVMCompatible(reducedShape), selfElemType), self, dimAttr); @@ -3947,11 +3935,11 @@ class ConvertAtenMinMaxDimOp : public OpConversionPattern { Value argMaxOp; if constexpr (std::is_same()) { Value negateOp = - rewriter.create(op->getLoc(), selfType, self); + tosa::NegateOp::create(rewriter, op->getLoc(), selfType, self); // Use default NaN Propagation mode "PROPAGATE" for tosa.argmax - argMaxOp = rewriter.create( - op->getLoc(), + argMaxOp = tosa::ArgMaxOp::create( + rewriter, op->getLoc(), RankedTensorType::get(makeShapeLLVMCompatible(prunedShape), indicesElemType), negateOp, dimAttr, /*nan_mode=*/ @@ -3959,8 +3947,8 @@ class ConvertAtenMinMaxDimOp : public OpConversionPattern { rewriter.getContext(), tosa::NanPropagationMode::PROPAGATE)); } else { // Use default NaN Propagation mode "PROPAGATE" for tosa.argmax - argMaxOp = rewriter.create( - op->getLoc(), + argMaxOp = tosa::ArgMaxOp::create( + rewriter, op->getLoc(), RankedTensorType::get(makeShapeLLVMCompatible(prunedShape), indicesElemType), self, dimAttr, /*nan_mode=*/ @@ -3969,14 +3957,14 @@ class ConvertAtenMinMaxDimOp : public 
OpConversionPattern { } if (argMaxOp.getType() != indicesType) { - argMaxOp = rewriter.create( - op->getLoc(), indicesType, argMaxOp, + argMaxOp = tosa::ReshapeOp::create( + rewriter, op->getLoc(), indicesType, argMaxOp, tosa::getTosaConstShape(rewriter, op->getLoc(), reducedShape)); } if (!keepDim) { - reduceOp = rewriter.create( - op->getLoc(), + reduceOp = tosa::ReshapeOp::create( + rewriter, op->getLoc(), RankedTensorType::get(makeShapeLLVMCompatible(prunedShape), selfElemType), reduceOp, prunedShapeValue); @@ -4135,8 +4123,8 @@ LogicalResult ConvertAtenOp::matchAndRewrite( // flexibility in rank differences and also offers more safety. Value reshapedInput = self; if (!llvm::equal(inputShape, targetInputShape)) - reshapedInput = rewriter.create( - op->getLoc(), + reshapedInput = tosa::ReshapeOp::create( + rewriter, op->getLoc(), RankedTensorType::get(makeShapeTorchCompatible(targetInputShape), selfElemTy), self, @@ -4154,8 +4142,8 @@ LogicalResult ConvertAtenOp::matchAndRewrite( auto tileOpMultiples = tosa::getTosaConstShape(rewriter, op->getLoc(), tileOpShape); - auto result = rewriter.create(op->getLoc(), resultType, - reshapedInput, tileOpMultiples); + auto result = tosa::TileOp::create(rewriter, op->getLoc(), resultType, + reshapedInput, tileOpMultiples); rewriter.replaceOp(op, {result.getResult()}); } @@ -4278,8 +4266,8 @@ LogicalResult ConvertAtenOp::matchAndRewrite( if (indexType.getRank() == 0) { indexShapeTorchCompatible = makeShapeTorchCompatible({1}); indexShape = indexShapeTorchCompatible; - index = rewriter.create( - op->getLoc(), + index = tosa::ReshapeOp::create( + rewriter, op->getLoc(), RankedTensorType::get(indexShape, indexType.getElementType()), index, tosa::getTosaConstShape(rewriter, op->getLoc(), indexShape)); } @@ -4334,8 +4322,8 @@ LogicalResult ConvertAtenOp::matchAndRewrite( RankedTensorType::get(makeShapeLLVMCompatible(indicesInputRankShape), rewriter.getIntegerType(32)); - auto reshapedIndices = rewriter.create( - op->getLoc(), 
indicesInputRankType, index, + auto reshapedIndices = tosa::ReshapeOp::create( + rewriter, op->getLoc(), indicesInputRankType, index, tosa::getTosaConstShape(rewriter, op->getLoc(), indicesInputRankShape)); SmallVector tileShape(indicesInputRankShape); @@ -4356,8 +4344,9 @@ LogicalResult ConvertAtenOp::matchAndRewrite( auto tileOpMultiples = tosa::getTosaConstShape(rewriter, op->getLoc(), tileShape); - auto expandedIndices = rewriter.create( - op->getLoc(), tileType, reshapedIndices.getResult(), tileOpMultiples); + auto expandedIndices = + tosa::TileOp::create(rewriter, op->getLoc(), tileType, + reshapedIndices.getResult(), tileOpMultiples); // convert torch style index and dim into tf style indices // tensor<[1,4,2],si64> -> tensor<[1,4,2,3],si64> @@ -4673,8 +4662,8 @@ LogicalResult ConvertAtenOp::matchAndRewrite( auto reshapeOutputTy = RankedTensorType::get( broadcastedShapeTf, idxType.getElementType()); // Update the tensor array with the max rank-extended form - indicesTfConcatTensors[i] = rewriter.create( - op->getLoc(), reshapeOutputTy, unreshapedIdxTensor, + indicesTfConcatTensors[i] = tosa::ReshapeOp::create( + rewriter, op->getLoc(), reshapeOutputTy, unreshapedIdxTensor, tosa::getTosaConstShape(rewriter, op->getLoc(), broadcastedShapeTf)); } @@ -4730,8 +4719,9 @@ LogicalResult ConvertAtenOp::matchAndRewrite( auto tileOpMultiples = tosa::getTosaConstShape(rewriter, op->getLoc(), tileOpShapeTf); - indicesTfConcatTensors[i] = rewriter.create( - op->getLoc(), tileOutputTy, reshapedIdxTensor, tileOpMultiples); + indicesTfConcatTensors[i] = + tosa::TileOp::create(rewriter, op->getLoc(), tileOutputTy, + reshapedIdxTensor, tileOpMultiples); } // Every index tensor now has the same rank and shape @@ -5115,15 +5105,15 @@ LogicalResult ConvertAtenOp::matchAndRewrite( otherType = dyn_cast(other.getType()); auto rhsSubOp = - rewriter.create(op->getLoc(), selfType, self, other); + tosa::SubOp::create(rewriter, op->getLoc(), selfType, self, other); auto rhsAbsOp = - 
rewriter.create(op->getLoc(), selfType, rhsSubOp); + tosa::AbsOp::create(rewriter, op->getLoc(), selfType, rhsSubOp); - auto lhsAbsOp = rewriter.create(op->getLoc(), otherType, other); + auto lhsAbsOp = tosa::AbsOp::create(rewriter, op->getLoc(), otherType, other); auto mulOp = tosa::createMulOpAndCast(rewriter, op, otherType, rtolConstOp, lhsAbsOp, /*shift=*/0); - auto addOp = - rewriter.create(op->getLoc(), otherType, atolConstOp, mulOp); + auto addOp = tosa::AddOp::create(rewriter, op->getLoc(), otherType, + atolConstOp, mulOp); auto outType = getTypeConverter()->convertType(op.getType()); rewriter.replaceOpWithNewOp(op, outType, addOp, @@ -5354,16 +5344,16 @@ LogicalResult ConvertAtenOp::matchAndRewrite( // max(xi, min_valuei) // Use default NaN Propagation mode "PROPAGATE" for tosa.maximum - auto minThresholdCheck = rewriter.create( - op->getLoc(), resultType, self, min, + auto minThresholdCheck = tosa::MaximumOp::create( + rewriter, op->getLoc(), resultType, self, min, /*nan_mode=*/ tosa::NanPropagationModeAttr::get(rewriter.getContext(), tosa::NanPropagationMode::PROPAGATE)); // yi = min(max(xi, min_valuei), max_valuei) // Use default NaN Propagation mode "PROPAGATE" for tosa.minimum - auto result = rewriter.create( - op->getLoc(), resultType, minThresholdCheck, max, + auto result = tosa::MinimumOp::create( + rewriter, op->getLoc(), resultType, minThresholdCheck, max, /*nan_mode=*/ tosa::NanPropagationModeAttr::get(rewriter.getContext(), tosa::NanPropagationMode::PROPAGATE)); @@ -5724,12 +5714,12 @@ class ConvertAtenRemainderFmodOp : public OpConversionPattern { if (isRemainderOp) { // torch.remainder(a, b) == a - a.div(b, rounding_mode="floor") * b if (isa(outElemTy)) { - auto otherTensorReciprocal = rewriter.create( - op.getLoc(), otherTensor.getType(), otherTensor); + auto otherTensorReciprocal = tosa::ReciprocalOp::create( + rewriter, op.getLoc(), otherTensor.getType(), otherTensor); divTensor = tosa::createMulOpAndCast( rewriter, op, outType, self, 
otherTensorReciprocal, /*shift=*/0); divTensor = - rewriter.create(op.getLoc(), outType, divTensor); + tosa::FloorOp::create(rewriter, op.getLoc(), outType, divTensor); } else { divTensor = floorIntDiv(rewriter, op, outType, self, otherTensor).value(); @@ -5746,8 +5736,8 @@ class ConvertAtenRemainderFmodOp : public OpConversionPattern { otherTensor = tosa::tosaCastTensorToType(rewriter, otherTensor, i32Type).value(); - auto intDivTensor = rewriter.create( - op->getLoc(), i32Type, self, otherTensor); + auto intDivTensor = tosa::IntDivOp::create(rewriter, op->getLoc(), + i32Type, self, otherTensor); divTensor = tosa::tosaCastTensorToType(rewriter, intDivTensor, outType).value(); @@ -5817,9 +5807,9 @@ class ConvertAtenPoolingBaseOp : public OpConversionPattern { SmallVector sizeSlice( dyn_cast(input.getType()).getShape()); sizeSlice[dimIndex] = inputDim - (remainderDim - padAfter); - input = rewriter.create( - loc, RankedTensorType::get(sizeSlice, inputElemTy), input, - tosa::getTosaConstShape(rewriter, loc, startSlice), + input = tosa::SliceOp::create( + rewriter, loc, RankedTensorType::get(sizeSlice, inputElemTy), + input, tosa::getTosaConstShape(rewriter, loc, startSlice), tosa::getTosaConstShape(rewriter, loc, sizeSlice)); dimSize = dimSize - padAfter; padAfter = 0; @@ -5861,10 +5851,9 @@ class ConvertAtenPoolingBaseOp : public OpConversionPattern { transposedInputShape.push_back(inputShape[dim]); auto transposedInputType = RankedTensorType::get( makeShapeLLVMCompatible(transposedInputShape), inputElemTy); - return rewriter - .create( - op->getLoc(), transposedInputType, input, - rewriter.getDenseI32ArrayAttr(transposedDims)) + return tosa::TransposeOp::create( + rewriter, op->getLoc(), transposedInputType, input, + rewriter.getDenseI32ArrayAttr(transposedDims)) .getResult(); } @@ -5909,8 +5898,8 @@ class ConvertAtenPoolingBaseOp : public OpConversionPattern { assert(inputTy.getRank() == 3 && "Expected input to be atleast 3 dimensional."); 
rank4Shape.insert(rank4Shape.begin(), 1); - input = rewriter.create( - loc, + input = tosa::ReshapeOp::create( + rewriter, loc, RankedTensorType::get(makeShapeTorchCompatible(rank4Shape), inputTy.getElementType()), input, tosa::getTosaConstShape(rewriter, loc, rank4Shape)); @@ -5951,22 +5940,20 @@ class ConvertAtenPoolingBaseOp : public OpConversionPattern { "Expected either tosa::MaxPool2dOp or tosa::AvgPool2dOp"); if constexpr (std::is_same::value) { // Use default NaN Propagation mode "PROPAGATE" for tosa.max_pool2d - pooledOutput = rewriter - .create( - op->getLoc(), outputTy, input, kernel, stride, pad, - /*nan_mode=*/ - tosa::NanPropagationModeAttr::get( - rewriter.getContext(), - tosa::NanPropagationMode::PROPAGATE)) - .getResult(); + pooledOutput = + TosaOpT::create( + rewriter, op->getLoc(), outputTy, input, kernel, stride, pad, + /*nan_mode=*/ + tosa::NanPropagationModeAttr::get( + rewriter.getContext(), tosa::NanPropagationMode::PROPAGATE)) + .getResult(); } else if constexpr (std::is_same::value) { TypeAttr accType; if (failed(tosa::getAvgPool2dAccType(rewriter, input, accType))) return rewriter.notifyMatchFailure( op, "Failed to get accumulator type for pooling"); - pooledOutput = rewriter - .create(op->getLoc(), outputTy, input, kernel, - stride, pad, accType) + pooledOutput = TosaOpT::create(rewriter, op->getLoc(), outputTy, input, + kernel, stride, pad, accType) .getResult(); } @@ -5984,8 +5971,8 @@ class ConvertAtenPoolingBaseOp : public OpConversionPattern { auto resultShape = expectedResultTy.getShape(); auto resultElemTy = expectedResultTy.getElementType(); - result = rewriter.create( - op->getLoc(), + result = tosa::ReshapeOp::create( + rewriter, op->getLoc(), RankedTensorType::get(makeShapeTorchCompatible(resultShape), resultElemTy), transposedOutput, @@ -6277,13 +6264,11 @@ class ConvertAtenMaxPool1dOp SmallVector rank4Shape(selfShape); rank4Shape.push_back(1); auto reshapedSelf = - rewriter - .create( - op->getLoc(), - 
RankedTensorType::get(makeShapeTorchCompatible(rank4Shape), - selfTy.getElementType()), - self, - tosa::getTosaConstShape(rewriter, op->getLoc(), rank4Shape)) + tosa::ReshapeOp::create( + rewriter, op->getLoc(), + RankedTensorType::get(makeShapeTorchCompatible(rank4Shape), + selfTy.getElementType()), + self, tosa::getTosaConstShape(rewriter, op->getLoc(), rank4Shape)) .getResult(); SmallVector dilationArray; @@ -6379,13 +6364,11 @@ class ConvertAtenAvgPool1dOp SmallVector rank4Shape(selfShape); rank4Shape.push_back(1); auto reshapedSelf = - rewriter - .create( - op->getLoc(), - RankedTensorType::get(makeShapeTorchCompatible(rank4Shape), - selfTy.getElementType()), - self, - tosa::getTosaConstShape(rewriter, op->getLoc(), rank4Shape)) + tosa::ReshapeOp::create( + rewriter, op->getLoc(), + RankedTensorType::get(makeShapeTorchCompatible(rank4Shape), + selfTy.getElementType()), + self, tosa::getTosaConstShape(rewriter, op->getLoc(), rank4Shape)) .getResult(); SmallVector dilationArray{1, 1}; @@ -6515,16 +6498,16 @@ class ConvertAtenFillOp : public OpConversionPattern { makeShapeTorchCompatible(fillValueMatchedInputRankShape), fillValueElemTy); - auto fillValueMatchedInputRankTensor = rewriter.create( - op->getLoc(), fillValueMatchedInputRankType, fillValue, + auto fillValueMatchedInputRankTensor = tosa::ReshapeOp::create( + rewriter, op->getLoc(), fillValueMatchedInputRankType, fillValue, tosa::getTosaConstShape(rewriter, op->getLoc(), fillValueMatchedInputRankShape)); auto tileOpMultiples = tosa::getTosaConstShape(rewriter, op->getLoc(), outType.getShape()); - fillValueTargetTensor = rewriter.create( - op->getLoc(), + fillValueTargetTensor = tosa::TileOp::create( + rewriter, op->getLoc(), RankedTensorType::get(makeShapeTorchCompatible(outType.getShape()), fillValueElemTy), fillValueMatchedInputRankTensor.getResult(), tileOpMultiples); @@ -6701,8 +6684,8 @@ LogicalResult ConvertAtenOp::matchAndRewrite( op, "Pad value needs to be a scalar constant for conversion to " 
"TOSA pad operation"); - padTensor = rewriter.create( - op->getLoc(), RankedTensorType::get({1}, selfElemTy), padTensor, + padTensor = tosa::ReshapeOp::create( + rewriter, op->getLoc(), RankedTensorType::get({1}, selfElemTy), padTensor, tosa::getTosaConstShape(rewriter, op->getLoc(), {1})); rewriter.replaceOpWithNewOp( @@ -6809,10 +6792,10 @@ ConvertAtenOp::matchAndRewrite( auto transposedInputTy = RankedTensorType::get( makeShapeLLVMCompatible(transposedInputShape), inputElemTy); auto transposedInput = - rewriter - .create( - op->getLoc(), getTypeConverter()->convertType(transposedInputTy), - input, rewriter.getDenseI32ArrayAttr(nchwToNhwcDims)) + tosa::TransposeOp::create( + rewriter, op->getLoc(), + getTypeConverter()->convertType(transposedInputTy), input, + rewriter.getDenseI32ArrayAttr(nchwToNhwcDims)) .getResult(); auto inputHeight = transposedInputShape[1]; @@ -6943,10 +6926,8 @@ ConvertAtenOp::matchAndRewrite( auto modeAttr = tosa::ResizeModeAttr::get(rewriter.getContext(), mode); auto resizeOpResult = - rewriter - .create(op->getLoc(), transposedResizedOpTy, - transposedInput, scale, offset, border, - modeAttr) + tosa::ResizeOp::create(rewriter, op->getLoc(), transposedResizedOpTy, + transposedInput, scale, offset, border, modeAttr) .getResult(); auto resultType = @@ -7099,8 +7080,8 @@ LogicalResult ConvertAtenOp::matchAndRewrite( if (!isValidDim(dim, selfRank)) return rewriter.notifyMatchFailure(op, "Not all dims are valid"); - result = rewriter.create(op->getLoc(), resultTy, result, - static_cast(dim)); + result = tosa::ReverseOp::create(rewriter, op->getLoc(), resultTy, result, + static_cast(dim)); } rewriter.replaceOp(op, result); @@ -7153,42 +7134,43 @@ LogicalResult ConvertAtenOp::matchAndRewrite( op, "Failed to equalize ranks among operands and result"); auto floorInput = - rewriter.create(op->getLoc(), resultTy, self); + tosa::FloorOp::create(rewriter, op->getLoc(), resultTy, self); // input - floor(input) - auto fractionalPart = rewriter.create( 
- op->getLoc(), resultTy, self, floorInput.getResult()); + auto fractionalPart = tosa::SubOp::create(rewriter, op->getLoc(), resultTy, + self, floorInput.getResult()); - auto ceilInput = rewriter.create(op->getLoc(), resultTy, self); + auto ceilInput = tosa::CeilOp::create(rewriter, op->getLoc(), resultTy, self); auto floorInputDivByTwo = tosa::createMulOpAndCast( rewriter, op, resultTy, floorInput.getResult(), oneHalf, /*shift=*/0); - auto floorDivResult = rewriter.create( - op->getLoc(), resultTy, floorInputDivByTwo.getResult()); + auto floorDivResult = tosa::FloorOp::create(rewriter, op->getLoc(), resultTy, + floorInputDivByTwo.getResult()); // (floor(input) // 2) * 2 auto evenComparison = tosa::createMulOpAndCast( rewriter, op, resultTy, floorDivResult.getResult(), two, /*shift=*/0); // floor(input) // 2) * 2 == input <=> floor(input) % 2 == 0 - auto floorInputEven = rewriter.create( - op->getLoc(), boolTy, floorInput.getResult(), evenComparison.getResult()); + auto floorInputEven = + tosa::EqualOp::create(rewriter, op->getLoc(), boolTy, + floorInput.getResult(), evenComparison.getResult()); - auto fracEqualOneHalf = rewriter.create( - op->getLoc(), boolTy, fractionalPart.getResult(), oneHalf); + auto fracEqualOneHalf = tosa::EqualOp::create( + rewriter, op->getLoc(), boolTy, fractionalPart.getResult(), oneHalf); - auto fracLtOneHalf = rewriter.create( - op->getLoc(), boolTy, oneHalf, fractionalPart.getResult()); + auto fracLtOneHalf = tosa::GreaterOp::create( + rewriter, op->getLoc(), boolTy, oneHalf, fractionalPart.getResult()); // (frac == 0.5) && (floor(input) % 2 == 0) - auto fracEqualOneHalfCond = rewriter.create( - op->getLoc(), boolTy, fracEqualOneHalf.getResult(), + auto fracEqualOneHalfCond = tosa::LogicalAndOp::create( + rewriter, op->getLoc(), boolTy, fracEqualOneHalf.getResult(), floorInputEven.getResult()); // (frac < 0.5) || ((frac == 0.5) && (floor(input) % 2 == 0)) - auto floorResultCond = rewriter.create( - op->getLoc(), boolTy, 
fracLtOneHalf.getResult(), + auto floorResultCond = tosa::LogicalOrOp::create( + rewriter, op->getLoc(), boolTy, fracLtOneHalf.getResult(), fracEqualOneHalfCond.getResult()); rewriter.replaceOpWithNewOp( @@ -7309,8 +7291,8 @@ LogicalResult ConvertAtenOp::matchAndRewrite( transposedInputType = RankedTensorType::get( makeShapeLLVMCompatible(transposedInputShape), selfElemTy); - selfTransposed = rewriter.create( - op->getLoc(), transposedInputType, self, + selfTransposed = tosa::TransposeOp::create( + rewriter, op->getLoc(), transposedInputType, self, rewriter.getDenseI32ArrayAttr(transposedDims)); } @@ -7368,8 +7350,8 @@ LogicalResult ConvertAtenOp::matchAndRewrite( llvm::to_vector(makeShapeTorchCompatible(transposedInputShape)); if (offset < 0) startSlice[targetDim1] = std::abs(offset); - diagonalTensor = rewriter.create( - op->getLoc(), transposedInputType, diagonalTensor, + diagonalTensor = tosa::SliceOp::create( + rewriter, op->getLoc(), transposedInputType, diagonalTensor, tosa::getTosaConstShape(rewriter, op->getLoc(), startSlice), tosa::getTosaConstShape(rewriter, op->getLoc(), sizeSlice)); } @@ -7474,8 +7456,8 @@ LogicalResult ConvertAtenOp::matchAndRewrite( // Reshape the input tensor to be the same shape as the new index tensor to // act as the src for scattering - auto scatterSrc = rewriter.create( - op->getLoc(), + auto scatterSrc = tosa::ReshapeOp::create( + rewriter, op->getLoc(), RankedTensorType::get(makeShapeTorchCompatible(indexShape), selfElemTy), self, tosa::getTosaConstShape(rewriter, op->getLoc(), indexShape)); @@ -7560,8 +7542,8 @@ LogicalResult ConvertAtenOp::matchAndRewrite( i++; } - auto result = rewriter.create( - op->getLoc(), resultType, diagonalTensor.value(), + auto result = tosa::TransposeOp::create( + rewriter, op->getLoc(), resultType, diagonalTensor.value(), rewriter.getDenseI32ArrayAttr(permutedDims)); rewriter.replaceOp(op, result.getResult()); @@ -7688,9 +7670,9 @@ LogicalResult ConvertAtenOp::matchAndRewrite( }); // Check: input 
<= threshold - auto cond = rewriter.create( - op->getLoc(), RankedTensorType::get(selfShape, rewriter.getI1Type()), - threshold, self); + auto cond = tosa::GreaterEqualOp::create( + rewriter, op->getLoc(), + RankedTensorType::get(selfShape, rewriter.getI1Type()), threshold, self); self = tosa::tosaCastTensorToType(rewriter, self, resultType).value(); grad = tosa::tosaCastTensorToType(rewriter, grad, resultType).value(); @@ -7700,8 +7682,8 @@ LogicalResult ConvertAtenOp::matchAndRewrite( return rewriter.notifyMatchFailure( op, "Failed to equalize ranks among operands and result"); - auto result = rewriter.create(op->getLoc(), resultType, - cond.getResult(), zero, grad); + auto result = tosa::SelectOp::create(rewriter, op->getLoc(), resultType, + cond.getResult(), zero, grad); rewriter.replaceOp(op, {result.getResult()}); @@ -7752,9 +7734,9 @@ LogicalResult ConvertAtenOp::matchAndRewrite( int64_t selfNumElems = std::accumulate(selfShape.begin(), selfShape.end(), 1, std::multiplies()); - auto self1D = rewriter.create( - op->getLoc(), RankedTensorType::get({selfNumElems}, selfElemTy), self, - tosa::getTosaConstShape(rewriter, op->getLoc(), {selfNumElems})); + auto self1D = tosa::ReshapeOp::create( + rewriter, op->getLoc(), RankedTensorType::get({selfNumElems}, selfElemTy), + self, tosa::getTosaConstShape(rewriter, op->getLoc(), {selfNumElems})); // Calculate the target elements indices SmallVector targetIndicesVec; @@ -7801,8 +7783,8 @@ LogicalResult ConvertAtenOp::matchAndRewrite( if (!gatherOp) return rewriter.notifyMatchFailure(op, "Convert GatherNdOp failed"); - auto result = rewriter.create( - op->getLoc(), resultType, gatherOp.value(), + auto result = tosa::ReshapeOp::create( + rewriter, op->getLoc(), resultType, gatherOp.value(), tosa::getTosaConstShape(rewriter, op->getLoc(), outputSize)); rewriter.replaceOp(op, {result.getResult()}); @@ -7891,13 +7873,14 @@ Value reflectionPadAlongAxis(Value input, ArrayRef unpaddedShape, auto leftPadType = 
RankedTensorType::get(leftPadShape, inputElemTy); - auto leftPadSlice = rewriter.create( - loc, leftPadType, input, + auto leftPadSlice = tosa::SliceOp::create( + rewriter, loc, leftPadType, input, tosa::getTosaConstShape(rewriter, loc, leftStartSlice), tosa::getTosaConstShape(rewriter, loc, leftSizeSlice)); - auto leftPad = rewriter.create( - loc, leftPadType, leftPadSlice.getResult(), static_cast(axis)); + auto leftPad = tosa::ReverseOp::create(rewriter, loc, leftPadType, + leftPadSlice.getResult(), + static_cast(axis)); resultTensors.push_back(leftPad.getResult()); } @@ -7924,14 +7907,14 @@ Value reflectionPadAlongAxis(Value input, ArrayRef unpaddedShape, auto rightPadType = RankedTensorType::get(rightPadShape, inputElemTy); - auto rightPadSlice = rewriter.create( - loc, rightPadType, input, + auto rightPadSlice = tosa::SliceOp::create( + rewriter, loc, rightPadType, input, tosa::getTosaConstShape(rewriter, loc, rightStartSlice), tosa::getTosaConstShape(rewriter, loc, rightSizeSlice)); - auto rightPad = rewriter.create( - loc, rightPadType, rightPadSlice.getResult(), - static_cast(axis)); + auto rightPad = tosa::ReverseOp::create(rewriter, loc, rightPadType, + rightPadSlice.getResult(), + static_cast(axis)); resultTensors.push_back(rightPad.getResult()); } @@ -8171,8 +8154,8 @@ LogicalResult ConvertAtenOp::matchAndRewrite( auto leftPadSliceType = RankedTensorType::get(leftPadSliceShape, selfElemTy); - auto leftPadSlice = rewriter.create( - op->getLoc(), leftPadSliceType, self, + auto leftPadSlice = tosa::SliceOp::create( + rewriter, op->getLoc(), leftPadSliceType, self, tosa::getTosaConstShape(rewriter, op->getLoc(), leftStartSlice), tosa::getTosaConstShape(rewriter, op->getLoc(), leftSizeSlice)); @@ -8196,8 +8179,8 @@ LogicalResult ConvertAtenOp::matchAndRewrite( auto rightPadSliceType = RankedTensorType::get(rightPadSliceShape, selfElemTy); - auto rightPadSlice = rewriter.create( - op->getLoc(), rightPadSliceType, self, + auto rightPadSlice = 
tosa::SliceOp::create( + rewriter, op->getLoc(), rightPadSliceType, self, tosa::getTosaConstShape(rewriter, op->getLoc(), rightStartSlice), tosa::getTosaConstShape(rewriter, op->getLoc(), rightSizeSlice)); @@ -8231,8 +8214,8 @@ LogicalResult ConvertAtenOp::matchAndRewrite( auto topPadSliceType = RankedTensorType::get(topPadSliceShape, selfElemTy); - auto topPadSlice = rewriter.create( - op->getLoc(), topPadSliceType, selfSidePadded, + auto topPadSlice = tosa::SliceOp::create( + rewriter, op->getLoc(), topPadSliceType, selfSidePadded, tosa::getTosaConstShape(rewriter, op->getLoc(), topStartSlice), tosa::getTosaConstShape(rewriter, op->getLoc(), topSizeSlice)); @@ -8259,8 +8242,8 @@ LogicalResult ConvertAtenOp::matchAndRewrite( auto bottomPadSliceType = RankedTensorType::get(bottomPadSliceShape, selfElemTy); - auto bottomPadSlice = rewriter.create( - op->getLoc(), bottomPadSliceType, selfSidePadded, + auto bottomPadSlice = tosa::SliceOp::create( + rewriter, op->getLoc(), bottomPadSliceType, selfSidePadded, tosa::getTosaConstShape(rewriter, op->getLoc(), bottomStartSlice), tosa::getTosaConstShape(rewriter, op->getLoc(), bottomSizeSlice)); @@ -8352,8 +8335,8 @@ LogicalResult ConvertAtenOp::matchAndRewrite( SmallVector resultShapeIndex0Replaced({1, resultShape[1]}); // Reshape and tile self to shape {selfShape[0], resultShape[1]} - auto selfReshaped = rewriter.create( - op->getLoc(), + auto selfReshaped = tosa::ReshapeOp::create( + rewriter, op->getLoc(), RankedTensorType::get(resultShapeIndex1Replaced, resultType.getElementType()), self, @@ -8363,12 +8346,13 @@ LogicalResult ConvertAtenOp::matchAndRewrite( auto selfTileOpMultiples = tosa::getTosaConstShape(rewriter, op->getLoc(), resultShapeIndex0Replaced); - auto selfTiled = rewriter.create( - op->getLoc(), resultType, selfReshaped.getResult(), selfTileOpMultiples); + auto selfTiled = + tosa::TileOp::create(rewriter, op->getLoc(), resultType, + selfReshaped.getResult(), selfTileOpMultiples); // Reshape and tile vec2 
to shape {resultShape[0], vec2Shape[0]} - auto vec2Reshaped = rewriter.create( - op->getLoc(), + auto vec2Reshaped = tosa::ReshapeOp::create( + rewriter, op->getLoc(), RankedTensorType::get(resultShapeIndex0Replaced, resultType.getElementType()), vec2, @@ -8378,8 +8362,9 @@ LogicalResult ConvertAtenOp::matchAndRewrite( auto vec2TileOpMultiples = tosa::getTosaConstShape(rewriter, op->getLoc(), resultShapeIndex1Replaced); - auto vec2Tiled = rewriter.create( - op->getLoc(), resultType, vec2Reshaped.getResult(), vec2TileOpMultiples); + auto vec2Tiled = + tosa::TileOp::create(rewriter, op->getLoc(), resultType, + vec2Reshaped.getResult(), vec2TileOpMultiples); auto result = tosa::createMulOpAndCast(rewriter, op, resultType, selfTiled.getResult(), @@ -8522,9 +8507,9 @@ class ConvertUpsampleNearest2dForward : public OpConversionPattern { selfShape.end() - 2); reshapedSelfShape.push_back(selfHeight * selfWidth); - auto reshapedSelf = rewriter.create( - op->getLoc(), RankedTensorType::get(reshapedSelfShape, selfElemTy), - self, + auto reshapedSelf = tosa::ReshapeOp::create( + rewriter, op->getLoc(), + RankedTensorType::get(reshapedSelfShape, selfElemTy), self, tosa::getTosaConstShape(rewriter, op->getLoc(), reshapedSelfShape)); // Calculate PyTorch-styled gather indices @@ -8565,8 +8550,8 @@ class ConvertUpsampleNearest2dForward : public OpConversionPattern { if (!gatherOp) return rewriter.notifyMatchFailure(op, "Convert GatherNdOp failed"); - auto result = rewriter.create( - op->getLoc(), resultType, gatherOp.value(), + auto result = tosa::ReshapeOp::create( + rewriter, op->getLoc(), resultType, gatherOp.value(), tosa::getTosaConstShape(rewriter, op->getLoc(), resultShape)); rewriter.replaceOp(op, {result.getResult()}); @@ -8634,14 +8619,13 @@ LogicalResult ConvertAtenOp::matchAndRewrite( // Clamp input to [eps, 1 - eps] when eps is not None // Use default NaN Propagation mode "PROPAGATE" for tosa.clamp if (!isEpsNone) { - zi = - rewriter - .create( - op->getLoc(), 
resultType, self, minFloatAttr, maxFloatAttr, - /*nan_mode=*/ - tosa::NanPropagationModeAttr::get( - rewriter.getContext(), tosa::NanPropagationMode::PROPAGATE)) - .getResult(); + zi = tosa::ClampOp::create( + rewriter, op->getLoc(), resultType, self, minFloatAttr, + maxFloatAttr, + /*nan_mode=*/ + tosa::NanPropagationModeAttr::get( + rewriter.getContext(), tosa::NanPropagationMode::PROPAGATE)) + .getResult(); } auto one = @@ -8652,17 +8636,17 @@ LogicalResult ConvertAtenOp::matchAndRewrite( op, "Failed to equalize ranks among operands and result"); auto oneMinusZi = - rewriter.create(op->getLoc(), resultType, one, zi); + tosa::SubOp::create(rewriter, op->getLoc(), resultType, one, zi); - auto oneMinusZiReciprocal = rewriter.create( - op->getLoc(), resultType, oneMinusZi.getResult()); + auto oneMinusZiReciprocal = tosa::ReciprocalOp::create( + rewriter, op->getLoc(), resultType, oneMinusZi.getResult()); auto mulOp = tosa::createMulOpAndCast(rewriter, op, resultType, zi, oneMinusZiReciprocal.getResult(), /*shift=*/0); - auto result = - rewriter.create(op->getLoc(), resultType, mulOp.getResult()); + auto result = tosa::LogOp::create(rewriter, op->getLoc(), resultType, + mulOp.getResult()); rewriter.replaceOp(op, {result.getResult()}); @@ -8703,10 +8687,10 @@ LogicalResult ConvertAtenOp::matchAndRewrite( op, "Failed to equalize ranks among operands and result"); auto addOp = - rewriter.create(op->getLoc(), resultType, self, one); + tosa::AddOp::create(rewriter, op->getLoc(), resultType, self, one); - auto result = - rewriter.create(op->getLoc(), resultType, addOp.getResult()); + auto result = tosa::LogOp::create(rewriter, op->getLoc(), resultType, + addOp.getResult()); rewriter.replaceOp(op, {result.getResult()}); @@ -8747,15 +8731,17 @@ LogicalResult ConvertAtenOp::matchAndRewrite( return rewriter.notifyMatchFailure( op, "Failed to equalize ranks among operands and result"); - auto logOfSelf = rewriter.create(op->getLoc(), resultType, self); + auto logOfSelf = + 
tosa::LogOp::create(rewriter, op->getLoc(), resultType, self); auto constTenType = RankedTensorType::get( dyn_cast(ten.getType()).getShape(), resultElemTy); - auto logOfTen = rewriter.create(op->getLoc(), constTenType, ten); + auto logOfTen = + tosa::LogOp::create(rewriter, op->getLoc(), constTenType, ten); - auto reciprocalOp = rewriter.create( - op->getLoc(), constTenType, logOfTen.getResult()); + auto reciprocalOp = tosa::ReciprocalOp::create( + rewriter, op->getLoc(), constTenType, logOfTen.getResult()); auto result = tosa::createMulOpAndCast( rewriter, op, resultType, logOfSelf.getResult(), reciprocalOp.getResult(), @@ -8801,10 +8787,10 @@ LogicalResult ConvertAtenOp::matchAndRewrite( return rewriter.notifyMatchFailure( op, "Failed to equalize ranks among operands and result"); - auto expOp = rewriter.create(op->getLoc(), resultType, self); + auto expOp = tosa::ExpOp::create(rewriter, op->getLoc(), resultType, self); - auto result = rewriter.create(op->getLoc(), resultType, - expOp.getResult(), one); + auto result = tosa::SubOp::create(rewriter, op->getLoc(), resultType, + expOp.getResult(), one); rewriter.replaceOp(op, {result.getResult()}); @@ -8835,12 +8821,12 @@ LogicalResult ConvertAtenOp::matchAndRewrite( if (!isa(selfType.getElementType())) self = tosa::tosaCastTensorToType(rewriter, self, resultType).value(); - auto sinOp = rewriter.create(op->getLoc(), resultType, self); + auto sinOp = tosa::SinOp::create(rewriter, op->getLoc(), resultType, self); - auto cosOp = rewriter.create(op->getLoc(), resultType, self); + auto cosOp = tosa::CosOp::create(rewriter, op->getLoc(), resultType, self); auto reciprocalOp = - rewriter.create(op->getLoc(), resultType, cosOp); + tosa::ReciprocalOp::create(rewriter, op->getLoc(), resultType, cosOp); auto result = tosa::createMulOpAndCast( rewriter, op, resultType, sinOp.getResult(), reciprocalOp.getResult(), @@ -8915,8 +8901,8 @@ LogicalResult ConvertAtenOp::matchAndRewrite( return rewriter.notifyMatchFailure( op, 
"Unsupported size value for rank zero input"); - auto result = rewriter.create( - op->getLoc(), RankedTensorType::get({1}, selfElemTy), self, + auto result = tosa::ReshapeOp::create( + rewriter, op->getLoc(), RankedTensorType::get({1}, selfElemTy), self, tosa::getTosaConstShape(rewriter, op->getLoc(), {1})); rewriter.replaceOp(op, {result.getResult()}); @@ -9017,8 +9003,9 @@ LogicalResult ConvertAtenOp::matchAndRewrite( } } - auto reshapeOp = rewriter.create( - op->getLoc(), RankedTensorType::get(intermediaryShape, resultElemTy), + auto reshapeOp = tosa::ReshapeOp::create( + rewriter, op->getLoc(), + RankedTensorType::get(intermediaryShape, resultElemTy), gatherNdOp.value(), tosa::getTosaConstShape(rewriter, op->getLoc(), intermediaryShape)); @@ -9030,8 +9017,8 @@ LogicalResult ConvertAtenOp::matchAndRewrite( } permutedDims.push_back(static_cast(dim + 1)); - auto result = rewriter.create( - op->getLoc(), resultType, reshapeOp.getResult(), + auto result = tosa::TransposeOp::create( + rewriter, op->getLoc(), resultType, reshapeOp.getResult(), rewriter.getDenseI32ArrayAttr(permutedDims)); rewriter.replaceOp(op, {result.getResult()}); diff --git a/lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp b/lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp index 3fc11f4fa13f..1d156f349f4f 100644 --- a/lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp +++ b/lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp @@ -217,7 +217,7 @@ Value getTosaConstTensorSingleF32(PatternRewriter &rewriter, Operation *op, auto const_attr = DenseElementsAttr::get(const_type, val); auto const_op = - rewriter.create(op->getLoc(), const_type, const_attr); + tosa::ConstOp::create(rewriter, op->getLoc(), const_type, const_attr); return const_op.getResult(); } @@ -229,7 +229,7 @@ Value getTosaMulShiftConstTensor(PatternRewriter &rewriter, Operation *op, shiftType, rewriter.getIntegerAttr(rewriter.getIntegerType(8), shift)); auto constShift = - rewriter.create(op->getLoc(), shiftType, shiftAttr); + 
tosa::ConstOp::create(rewriter, op->getLoc(), shiftType, shiftAttr); return constShift.getResult(); } @@ -280,7 +280,7 @@ std::optional getConstTensor(PatternRewriter &rewriter, Operation *op, auto const_attr = DenseElementsAttr::get(const_type, vec); auto const_op = - rewriter.create(op->getLoc(), const_type, const_attr); + tosa::ConstOp::create(rewriter, op->getLoc(), const_type, const_attr); if (dtype) { return tosa::tosaCastTensorToType(rewriter, const_op, @@ -311,7 +311,7 @@ std::optional getConstTensor(PatternRewriter &rewriter, auto const_attr = DenseElementsAttr::get(const_type, vec); auto const_op = - rewriter.create(op->getLoc(), const_type, const_attr); + tosa::ConstOp::create(rewriter, op->getLoc(), const_type, const_attr); if (dtype) { return tosa::tosaCastTensorToType(rewriter, const_op, @@ -341,7 +341,7 @@ std::optional getConstTensor(PatternRewriter &rewriter, auto const_attr = DenseElementsAttr::get(const_type, vec); auto const_op = - rewriter.create(op->getLoc(), const_type, const_attr); + tosa::ConstOp::create(rewriter, op->getLoc(), const_type, const_attr); if (dtype) { return tosa::tosaCastTensorToType(rewriter, const_op, @@ -449,18 +449,18 @@ std::optional tosaCastTensorToType(PatternRewriter &rewriter, Value src, auto cmpTy = srcType.clone(rewriter.getIntegerType(1)); Value isEq = - rewriter.create(op->getLoc(), cmpTy, src, zeroValue); - return rewriter.create(op->getLoc(), - srcType.clone(destElemTy), isEq); + tosa::EqualOp::create(rewriter, op->getLoc(), cmpTy, src, zeroValue); + return tosa::LogicalNotOp::create(rewriter, op->getLoc(), + srcType.clone(destElemTy), isEq); } if (srcElemTy.isInteger(1) && llvm::isa(destElemTy)) { // TOSA does not support casting from i1->float. // Instead, we cast to i8 and then to the float. 
TensorType midType = srcType.clone(rewriter.getIntegerType(8)); - Value mid = rewriter.create(op->getLoc(), midType, src); - return rewriter.create(op->getLoc(), - srcType.clone(destElemTy), mid); + Value mid = tosa::CastOp::create(rewriter, op->getLoc(), midType, src); + return tosa::CastOp::create(rewriter, op->getLoc(), + srcType.clone(destElemTy), mid); } if (srcElemTy == destElemTy) @@ -472,8 +472,8 @@ std::optional tosaCastTensorToType(PatternRewriter &rewriter, Value src, // PyTorch performs round-to-zero instead. // Generate round-to-zero conversion prior to tosa.cast to match with // expected torch behavior. - auto floor = rewriter.create(op->getLoc(), srcType, src); - auto ceil = rewriter.create(op->getLoc(), srcType, src); + auto floor = tosa::FloorOp::create(rewriter, op->getLoc(), srcType, src); + auto ceil = tosa::CeilOp::create(rewriter, op->getLoc(), srcType, src); auto zeroValue = tosa::getConstTensor(rewriter, op, 0, {}, srcElemTy).value(); @@ -490,7 +490,7 @@ std::optional tosaCastTensorToType(PatternRewriter &rewriter, Value src, } TensorType castedSrcType = srcType.clone(destElemTy); - return rewriter.create(op->getLoc(), castedSrcType, src); + return tosa::CastOp::create(rewriter, op->getLoc(), castedSrcType, src); } // Template instantiation diff --git a/lib/Conversion/Utils/Utils.cpp b/lib/Conversion/Utils/Utils.cpp index 3a5a5a7447c8..0bff646023cf 100644 --- a/lib/Conversion/Utils/Utils.cpp +++ b/lib/Conversion/Utils/Utils.cpp @@ -54,13 +54,13 @@ Value toPositiveDimDynamic(OpBuilder &b, Location loc, Value dim, Value inputRank) { assert(isa(dim.getType()) && "dim arg of toPositiveDim must be integer type"); - Value dimAddInputRank = b.create(loc, dim, inputRank); + Value dimAddInputRank = arith::AddIOp::create(b, loc, dim, inputRank); Value cst0 = - b.create(loc, b.getZeroAttr(inputRank.getType())); + arith::ConstantOp::create(b, loc, b.getZeroAttr(inputRank.getType())); Value predDimGEZero = - b.create(loc, arith::CmpIPredicate::sge, dim, 
cst0); + arith::CmpIOp::create(b, loc, arith::CmpIPredicate::sge, dim, cst0); Value dimInt = - b.create(loc, predDimGEZero, dim, dimAddInputRank); + arith::SelectOp::create(b, loc, predDimGEZero, dim, dimAddInputRank); return dimInt; } @@ -69,15 +69,15 @@ void assertIsValidDim(OpBuilder &b, Location loc, Value dim, Value inputRank) { assert(isa(dim.getType()) && "dim arg of assertIsValidDim must be integer type"); Value cst0 = - b.create(loc, b.getZeroAttr(inputRank.getType())); + arith::ConstantOp::create(b, loc, b.getZeroAttr(inputRank.getType())); Value predGEZero = - b.create(loc, arith::CmpIPredicate::sge, dim, cst0); - b.create( - loc, predGEZero, b.getStringAttr("dim must be greater or equal to zero")); + arith::CmpIOp::create(b, loc, arith::CmpIPredicate::sge, dim, cst0); + cf::AssertOp::create(b, loc, predGEZero, + b.getStringAttr("dim must be greater or equal to zero")); Value predLTInputRank = - b.create(loc, arith::CmpIPredicate::slt, dim, inputRank); - b.create(loc, predLTInputRank, - b.getStringAttr("dim must be smaller than inputRank")); + arith::CmpIOp::create(b, loc, arith::CmpIPredicate::slt, dim, inputRank); + cf::AssertOp::create(b, loc, predLTInputRank, + b.getStringAttr("dim must be smaller than inputRank")); } // Hack to deal with the Torch list type arguments which is not supported end @@ -113,10 +113,10 @@ void checkDimEqualHelper(OpBuilder &b, Location loc, Value lhsDim, lhsType.isIndex() ? castIndexToInt64(b, loc, lhsDim) : lhsDim; Value rhsDimInt = rhsType.isIndex() ? 
castIndexToInt64(b, loc, rhsDim) : rhsDim; - Value contractingDimEqual = b.create( - loc, arith::CmpIPredicate::eq, lhsDimInt, rhsDimInt); - b.create(loc, contractingDimEqual, - b.getStringAttr("mismatching contracting dimension")); + Value contractingDimEqual = arith::CmpIOp::create( + b, loc, arith::CmpIPredicate::eq, lhsDimInt, rhsDimInt); + cf::AssertOp::create(b, loc, contractingDimEqual, + b.getStringAttr("mismatching contracting dimension")); } // Creates a tensor with required `sizes` and `elemTy` and fills it with @@ -124,34 +124,34 @@ void checkDimEqualHelper(OpBuilder &b, Location loc, Value lhsDim, Value createInitTensor(OpBuilder &b, Location loc, ValueRange sizes, Type elemTy, Value initElem) { Value initTensor = - b.create(loc, getAsOpFoldResult(sizes), elemTy); - return b.create(loc, initElem, initTensor).getResult(0); + tensor::EmptyOp::create(b, loc, getAsOpFoldResult(sizes), elemTy); + return linalg::FillOp::create(b, loc, initElem, initTensor).getResult(0); } Value createZeroInitTensor(OpBuilder &b, Location loc, ValueRange sizes, Type elemTy) { Value initTensor = - b.create(loc, getAsOpFoldResult(sizes), elemTy); + tensor::EmptyOp::create(b, loc, getAsOpFoldResult(sizes), elemTy); Type fillValElemTy = elemTy; if (auto dtypeComplex = dyn_cast(elemTy)) fillValElemTy = cast(dtypeComplex.getElementType()); - Value c0 = b.create(loc, b.getZeroAttr(fillValElemTy)); - return b.create(loc, c0, initTensor).getResult(0); + Value c0 = arith::ConstantOp::create(b, loc, b.getZeroAttr(fillValElemTy)); + return linalg::FillOp::create(b, loc, c0, initTensor).getResult(0); } Value createOneInitTensor(OpBuilder &b, Location loc, ValueRange sizes, Type elemTy) { Value initTensor = - b.create(loc, getAsOpFoldResult(sizes), elemTy); + tensor::EmptyOp::create(b, loc, getAsOpFoldResult(sizes), elemTy); Type fillValElemTy = elemTy; if (auto dtypeComplex = dyn_cast(elemTy)) fillValElemTy = cast(dtypeComplex.getElementType()); - Value c1 = b.create(loc, 
b.getOneAttr(fillValElemTy)); - return b.create(loc, c1, initTensor).getResult(0); + Value c1 = arith::ConstantOp::create(b, loc, b.getOneAttr(fillValElemTy)); + return linalg::FillOp::create(b, loc, c1, initTensor).getResult(0); } Value castIntToIndex(OpBuilder &b, Location loc, Value v) { @@ -205,9 +205,9 @@ SmallVector getTensorSizes(OpBuilder &b, Location loc, Value tensor) { Value getTensorSize(OpBuilder &b, Location loc, Value tensor) { SmallVector sizes(getTensorSizes(b, loc, tensor)); - Value productResult = b.create(loc, b.getIndexAttr(1)); + Value productResult = arith::ConstantOp::create(b, loc, b.getIndexAttr(1)); for (Value size : sizes) - productResult = b.create(loc, productResult, size); + productResult = arith::MulIOp::create(b, loc, productResult, size); return castIndexToInt64(b, loc, productResult); } @@ -223,21 +223,21 @@ Value getConstant(OpBuilder &b, Location loc, int64_t val, Type elemType) { APInt(cast(elemType).getWidth(), val)); if (!attr) return nullptr; - return b.create(loc, elemType, attr); + return arith::ConstantOp::create(b, loc, elemType, attr); } SmallVector getAsConstantIntValues(OpBuilder &b, Location loc, SmallVectorImpl &ints) { return llvm::to_vector<4>(llvm::map_range(ints, [&](int64_t val) -> Value { - return b.create(loc, - b.getIntegerAttr(b.getI64Type(), val)); + return arith::ConstantOp::create(b, loc, + b.getIntegerAttr(b.getI64Type(), val)); })); } SmallVector getAsConstantIndexValues(OpBuilder &b, Location loc, SmallVectorImpl &ints) { return llvm::to_vector<4>(llvm::map_range(ints, [&](int64_t val) -> Value { - return b.create(loc, b.getIndexAttr(val)); + return arith::ConstantOp::create(b, loc, b.getIndexAttr(val)); })); } @@ -318,13 +318,14 @@ Value convertScalarToDtype(OpBuilder &b, Location loc, Value scalar, Type dtype, // If the dtype is i1, i.e., a boolean type. 
if (dtype.isSignlessInteger(1)) { Type scalarType = scalar.getType(); - Value cstZero = b.create(loc, b.getZeroAttr(scalarType)); + Value cstZero = + arith::ConstantOp::create(b, loc, b.getZeroAttr(scalarType)); if (isa(scalarType)) { - return b.create(loc, arith::CmpFPredicate::UNE, scalar, - cstZero); + return arith::CmpFOp::create(b, loc, arith::CmpFPredicate::UNE, scalar, + cstZero); } else if (isa(scalarType)) { - return b.create(loc, arith::CmpIPredicate::ne, scalar, - cstZero); + return arith::CmpIOp::create(b, loc, arith::CmpIPredicate::ne, scalar, + cstZero); } else { mlir::emitError(loc) << "unsupported scalar type for convertScalarToDtype " << scalarType @@ -336,37 +337,37 @@ Value convertScalarToDtype(OpBuilder &b, Location loc, Value scalar, Type dtype, if (auto dtypeFloat = dyn_cast(dtype)) { if (auto scalarFloat = dyn_cast(scalarType)) { if (scalarFloat.getWidth() == 16 && dtypeFloat.getWidth() == 16) { - auto scalarF32 = b.create(loc, b.getF32Type(), scalar); - return b.create(loc, dtype, scalarF32); + auto scalarF32 = arith::ExtFOp::create(b, loc, b.getF32Type(), scalar); + return arith::TruncFOp::create(b, loc, dtype, scalarF32); } if (scalarFloat.getWidth() > dtypeFloat.getWidth()) - return b.create(loc, dtype, scalar); + return arith::TruncFOp::create(b, loc, dtype, scalar); // Only scalarFloat width < dtypeFloat width can reach here. - return b.create(loc, dtype, scalar); + return arith::ExtFOp::create(b, loc, dtype, scalar); } assert(isa(scalarType)); if (scalarType.isSignlessInteger(1) || (srcOriginalDtype.has_value() && srcOriginalDtype->isUnsignedInteger())) - return b.create(loc, dtype, scalar); + return arith::UIToFPOp::create(b, loc, dtype, scalar); // It's safe to use SIToFPOp because ui8/si8 are the only ones where // unsigned handling is needed, and we checked for that case above. 
- return b.create(loc, dtype, scalar); + return arith::SIToFPOp::create(b, loc, dtype, scalar); } if (auto dtypeInteger = dyn_cast(dtype)) { if (auto scalarFloat = dyn_cast(scalarType)) - return b.create(loc, dtype, scalar); + return arith::FPToSIOp::create(b, loc, dtype, scalar); assert(isa(scalarType)); auto scalarInteger = cast(scalarType); if (scalarInteger.getWidth() > dtypeInteger.getWidth()) - return b.create(loc, dtype, scalar); + return arith::TruncIOp::create(b, loc, dtype, scalar); if (scalarType.isSignlessInteger(1) || (srcOriginalDtype.has_value() && srcOriginalDtype->isUnsignedInteger())) - return b.create(loc, dtype, scalar); + return arith::ExtUIOp::create(b, loc, dtype, scalar); // Only scalarInteger width < dtypeInteger width can reach here. // It's safe to use ExtSIOp here because ui8/si8 are the only ones where // unsigned handling is needed, and we checked for that case above. - return b.create(loc, dtype, scalar); + return arith::ExtSIOp::create(b, loc, dtype, scalar); } if (auto dtypeComplex = dyn_cast(dtype)) { @@ -378,13 +379,13 @@ Value convertScalarToDtype(OpBuilder &b, Location loc, Value scalar, Type dtype, // Extract the real and imaginary parts of the scalar. // Cast them to the target element type, and create a new complex // value with the target complex type. - Value realVal = b.create(loc, scalar); - Value imgVal = b.create(loc, scalar); + Value realVal = complex::ReOp::create(b, loc, scalar); + Value imgVal = complex::ImOp::create(b, loc, scalar); realVal = convertScalarToDtype(b, loc, realVal, dtypeElemType); imgVal = convertScalarToDtype(b, loc, imgVal, dtypeElemType); - return b.create(loc, dtypeComplex, realVal, imgVal); + return complex::CreateOp::create(b, loc, dtypeComplex, realVal, imgVal); } // Float to complex type. 
@@ -393,17 +394,17 @@ Value convertScalarToDtype(OpBuilder &b, Location loc, Value scalar, Type dtype, cast(dtypeComplex.getElementType()); Value realVal; Value imgVal = - b.create(loc, b.getZeroAttr(complexElementType)); + arith::ConstantOp::create(b, loc, b.getZeroAttr(complexElementType)); if (complexElementType.getWidth() > dtypeFloat.getWidth()) { - realVal = b.create(loc, complexElementType, scalar); + realVal = arith::ExtFOp::create(b, loc, complexElementType, scalar); } else if (complexElementType.getWidth() < dtypeFloat.getWidth()) { - realVal = b.create(loc, complexElementType, scalar); + realVal = arith::TruncFOp::create(b, loc, complexElementType, scalar); } else { realVal = scalar; } - return b.create(loc, dtypeComplex, realVal, imgVal); + return complex::CreateOp::create(b, loc, dtypeComplex, realVal, imgVal); } // Int to complex type. @@ -412,11 +413,11 @@ Value convertScalarToDtype(OpBuilder &b, Location loc, Value scalar, Type dtype, cast(dtypeComplex.getElementType()); Value realVal = - b.create(loc, complexElementType, scalar); + arith::SIToFPOp::create(b, loc, complexElementType, scalar); Value imgVal = - b.create(loc, b.getZeroAttr(complexElementType)); + arith::ConstantOp::create(b, loc, b.getZeroAttr(complexElementType)); - return b.create(loc, dtypeComplex, realVal, imgVal); + return complex::CreateOp::create(b, loc, dtypeComplex, realVal, imgVal); } mlir::emitError(loc) << "unsupported scalar type for convertScalarToDtype " @@ -436,17 +437,17 @@ Value toPositiveValidDim(ConversionPatternRewriter &rewriter, Location loc, Value positiveDim = toPositiveDimDynamic(rewriter, loc, builtinInt, dimSizeAsInt); // positiveDim < 0 ? 
0 : positiveDim - Value cst0 = rewriter.create( - loc, rewriter.getZeroAttr(dimSizeAsInt.getType())); - Value predDimSltZero = rewriter.create( - loc, arith::CmpIPredicate::slt, positiveDim, cst0); + Value cst0 = arith::ConstantOp::create( + rewriter, loc, rewriter.getZeroAttr(dimSizeAsInt.getType())); + Value predDimSltZero = arith::CmpIOp::create( + rewriter, loc, arith::CmpIPredicate::slt, positiveDim, cst0); Value atLeastZero = - rewriter.create(loc, predDimSltZero, cst0, positiveDim); + arith::SelectOp::create(rewriter, loc, predDimSltZero, cst0, positiveDim); // atLeastZero > dimSizeAsInt ? dimSizeAsInt : atLeastZero - Value sgtDimSize = rewriter.create( - loc, arith::CmpIPredicate::sgt, atLeastZero, dimSizeAsInt); - Value boundedByDimSize = rewriter.create( - loc, sgtDimSize, dimSizeAsInt, atLeastZero); + Value sgtDimSize = arith::CmpIOp::create( + rewriter, loc, arith::CmpIPredicate::sgt, atLeastZero, dimSizeAsInt); + Value boundedByDimSize = arith::SelectOp::create(rewriter, loc, sgtDimSize, + dimSizeAsInt, atLeastZero); return castIntToIndex(rewriter, loc, boundedByDimSize); } @@ -491,8 +492,8 @@ FailureOr unsqueezeTensor(PatternRewriter &rewriter, Operation *op, } } } - Value unsqueezed = rewriter.create( - op->getLoc(), unsqueezedType, input, reassociationMap); + Value unsqueezed = tensor::ExpandShapeOp::create( + rewriter, op->getLoc(), unsqueezedType, input, reassociationMap); return unsqueezed; } @@ -514,13 +515,13 @@ FailureOr squeezeTensor(PatternRewriter &rewriter, Operation *op, // assert dynamic squeeze dim size == 1 if (inputType.isDynamicDim(dim)) { - Value cstDim = rewriter.create(loc, dim); - Value dimVal = rewriter.create(loc, input, cstDim); - Value cstOne = rewriter.create(loc, 1); - Value cmp = rewriter.create(loc, arith::CmpIPredicate::eq, - dimVal, cstOne); - rewriter.create( - loc, cmp, + Value cstDim = arith::ConstantIndexOp::create(rewriter, loc, dim); + Value dimVal = tensor::DimOp::create(rewriter, loc, input, cstDim); + Value 
cstOne = arith::ConstantIndexOp::create(rewriter, loc, 1); + Value cmp = arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::eq, + dimVal, cstOne); + cf::AssertOp::create( + rewriter, loc, cmp, rewriter.getStringAttr( "Expected dynamic squeeze dim size to be statically 1")); } @@ -559,8 +560,8 @@ FailureOr squeezeTensor(PatternRewriter &rewriter, Operation *op, // Note: In case the operand tensor type is of unit rank and is statically // shaped with unit dimension, the `reassociationMap` will be empty and the // input will be collapsed to a 0-D tensor. - Value squeezed = rewriter.create( - op->getLoc(), squeezedType, input, reassociationMap); + Value squeezed = tensor::CollapseShapeOp::create( + rewriter, op->getLoc(), squeezedType, input, reassociationMap); return squeezed; } diff --git a/lib/Dialect/TMTensor/IR/TMTensorOps.cpp b/lib/Dialect/TMTensor/IR/TMTensorOps.cpp index 75db622da759..a05f770e2793 100644 --- a/lib/Dialect/TMTensor/IR/TMTensorOps.cpp +++ b/lib/Dialect/TMTensor/IR/TMTensorOps.cpp @@ -68,10 +68,10 @@ Value TMTensor::getDimValue(OpBuilder &builder, Location loc, Value v, int64_t dim) { return TypeSwitch(v.getType()) .Case([&](RankedTensorType t) -> Value { - return builder.create(loc, v, dim); + return tensor::DimOp::create(builder, loc, v, dim); }) .Case([&](MemRefType t) -> Value { - return builder.create(loc, v, dim); + return memref::DimOp::create(builder, loc, v, dim); }) .Default([&](Type t) { return Value(); }); } @@ -160,39 +160,39 @@ static void matmul(OpBuilder &b, Location loc, Value lhs, ValueRange lhsSizes, Value rhs, ValueRange rhsSizes, Value output, ValueRange outputSizes, bool transposed = false) { auto elementType = cast(lhs.getType()).getElementType(); - Value one = b.create(loc, 1); - Value zero = b.create(loc, 0); + Value one = arith::ConstantIndexOp::create(b, loc, 1); + Value zero = arith::ConstantIndexOp::create(b, loc, 0); auto rank = outputSizes.size(); Value reductionDimSize = lhsSizes[lhsSizes.size() - 1]; // 
Loop over output - b.create( - loc, SmallVector(rank, zero), outputSizes, + scf::ParallelOp::create( + b, loc, SmallVector(rank, zero), outputSizes, SmallVector(rank, one), [&](OpBuilder &b, Location loc, ValueRange localIVs) { - Value acc = b.create( - loc, elementType, b.getFloatAttr(elementType, 0.0)); + Value acc = arith::ConstantOp::create(b, loc, elementType, + b.getFloatAttr(elementType, 0.0)); Value sum = - b.create( - loc, zero, reductionDimSize, one, SmallVector{acc}, - [&](OpBuilder &b, Location loc, Value i, ValueRange accs) { - SmallVector lhsIVs(localIVs), rhsIVs(localIVs); - lhsIVs[lhsIVs.size() - 1] = i; - rhsIVs[rhsIVs.size() - 2] = i; - if (transposed) - std::swap(rhsIVs[rhsIVs.size() - 1], - rhsIVs[rhsIVs.size() - 2]); - - Value acc = accs[0]; - Value rElem = b.create(loc, lhs, lhsIVs); - Value cElem = b.create(loc, rhs, rhsIVs); - Value x = b.create(loc, rElem, cElem); - x = b.create(loc, x, acc); - - b.create(loc, x); - }) + scf::ForOp::create( + b, loc, zero, reductionDimSize, one, SmallVector{acc}, + [&](OpBuilder &b, Location loc, Value i, ValueRange accs) { + SmallVector lhsIVs(localIVs), rhsIVs(localIVs); + lhsIVs[lhsIVs.size() - 1] = i; + rhsIVs[rhsIVs.size() - 2] = i; + if (transposed) + std::swap(rhsIVs[rhsIVs.size() - 1], + rhsIVs[rhsIVs.size() - 2]); + + Value acc = accs[0]; + Value rElem = memref::LoadOp::create(b, loc, lhs, lhsIVs); + Value cElem = memref::LoadOp::create(b, loc, rhs, rhsIVs); + Value x = arith::MulFOp::create(b, loc, rElem, cElem); + x = arith::AddFOp::create(b, loc, x, acc); + + scf::YieldOp::create(b, loc, x); + }) ->getResult(0); - b.create(loc, sum, output, localIVs); + memref::StoreOp::create(b, loc, sum, output, localIVs); }); } @@ -218,23 +218,23 @@ LogicalResult AttentionOp::generateScalarImplementation(OpBuilder &b, auto keySizes = keyType.getShape(); Type elementType = queryType.getElementType(); - Value zeroF = b.create(loc, elementType, - b.getFloatAttr(elementType, 0.0)); - Value negInfF = b.create( - 
loc, elementType, + Value zeroF = arith::ConstantOp::create(b, loc, elementType, + b.getFloatAttr(elementType, 0.0)); + Value negInfF = arith::ConstantOp::create( + b, loc, elementType, b.getFloatAttr(elementType, -std::numeric_limits::infinity())); // TODO: This needs to be fixed, it assumes everything is dynamic however if // any shapes are static the `memref.alloc` generated is illegal. SmallVector queryDynSizes, keyDynSizes, valueDynSizes, outputDynSizes; for (auto i = 0; i < queryRank; i++) - queryDynSizes.push_back(b.create(loc, query, i)); + queryDynSizes.push_back(memref::DimOp::create(b, loc, query, i)); for (auto i = 0; i < keyRank; i++) - keyDynSizes.push_back(b.create(loc, key, i)); + keyDynSizes.push_back(memref::DimOp::create(b, loc, key, i)); for (auto i = 0; i < valueRank; i++) - valueDynSizes.push_back(b.create(loc, value, i)); + valueDynSizes.push_back(memref::DimOp::create(b, loc, value, i)); for (auto i = 0; i < queryRank; i++) - outputDynSizes.push_back(b.create(loc, output, i)); + outputDynSizes.push_back(memref::DimOp::create(b, loc, output, i)); // weight = query @ key auto weightRank = queryRank; @@ -252,145 +252,146 @@ LogicalResult AttentionOp::generateScalarImplementation(OpBuilder &b, weightFilteredDynSizes.push_back(weightDynSizes[i]); Value weight = - b.create(loc, weightType, weightFilteredDynSizes); + memref::AllocOp::create(b, loc, weightType, weightFilteredDynSizes); matmul(b, loc, query, queryDynSizes, key, keyDynSizes, weight, weightDynSizes, /*transposed=*/true); // weight = softmax(weight) Value dim = weightDynSizes[weightRank - 1]; - Value scaleFactor = b.create( - loc, b.create( - loc, elementType, - b.create(loc, b.getI32Type(), - queryDynSizes[queryRank - 1]))); + Value scaleFactor = math::SqrtOp::create( + b, loc, + arith::UIToFPOp::create( + b, loc, elementType, + arith::IndexCastUIOp::create(b, loc, b.getI32Type(), + queryDynSizes[queryRank - 1]))); // weight = (weight - max(weight)) / math.sqrt(querySizes[-1]) - Value 
one = b.create(loc, 1); - Value zero = b.create(loc, 0); - b.create( - loc, SmallVector(weightRank, zero), weightDynSizes, + Value one = arith::ConstantIndexOp::create(b, loc, 1); + Value zero = arith::ConstantIndexOp::create(b, loc, 0); + scf::ParallelOp::create( + b, loc, SmallVector(weightRank, zero), weightDynSizes, SmallVector(weightRank, one), [&](OpBuilder &b, Location loc, ValueRange localIVs) { - Value x = b.create(loc, weight, localIVs); - x = b.create(loc, x, scaleFactor); - b.create(loc, x, weight, localIVs); + Value x = memref::LoadOp::create(b, loc, weight, localIVs); + x = arith::DivFOp::create(b, loc, x, scaleFactor); + memref::StoreOp::create(b, loc, x, weight, localIVs); }); // Apply mask to weights if mask is given if (mask) { - b.create( - loc, SmallVector(weightRank, zero), weightDynSizes, + scf::ParallelOp::create( + b, loc, SmallVector(weightRank, zero), weightDynSizes, SmallVector(weightRank, one), [&](OpBuilder &b, Location loc, ValueRange localIVs) { - Value weightValue = b.create(loc, weight, localIVs); - Value maskValue = b.create(loc, mask, localIVs); + Value weightValue = memref::LoadOp::create(b, loc, weight, localIVs); + Value maskValue = memref::LoadOp::create(b, loc, mask, localIVs); if (maskType.getElementType().isInteger(1)) { maskValue = - b.create(loc, maskValue, zeroF, negInfF); + arith::SelectOp::create(b, loc, maskValue, zeroF, negInfF); } Value maskedWeight = - b.create(loc, weightValue, maskValue); - b.create(loc, maskedWeight, weight, localIVs); + arith::AddFOp::create(b, loc, weightValue, maskValue); + memref::StoreOp::create(b, loc, maskedWeight, weight, localIVs); }); } // calculate max(weight) - Value init = b.create(loc, weight, - SmallVector(weightRank, zero)); + Value init = memref::LoadOp::create(b, loc, weight, + SmallVector(weightRank, zero)); Value globalMax = - b.create( - loc, SmallVector(weightRank, zero), weightDynSizes, - SmallVector(weightRank, one), init, - [&](OpBuilder &b, Location loc, ValueRange 
localIVs, - ValueRange accs) { - auto reduceOp = b.create(loc, init); - // Build reduce body. - Block &reductionBody = reduceOp.getReductions()[0].front(); - auto bodyBuilder = OpBuilder::atBlockEnd(&reductionBody); - Value acc = reductionBody.getArgument(0); - Value x = - bodyBuilder.create(loc, weight, localIVs); - Value max = bodyBuilder.create(loc, x, acc); - bodyBuilder.create(loc, max); - }) + scf::ParallelOp::create( + b, loc, SmallVector(weightRank, zero), weightDynSizes, + SmallVector(weightRank, one), init, + [&](OpBuilder &b, Location loc, ValueRange localIVs, + ValueRange accs) { + auto reduceOp = scf::ReduceOp::create(b, loc, init); + // Build reduce body. + Block &reductionBody = reduceOp.getReductions()[0].front(); + auto bodyBuilder = OpBuilder::atBlockEnd(&reductionBody); + Value acc = reductionBody.getArgument(0); + Value x = + memref::LoadOp::create(bodyBuilder, loc, weight, localIVs); + Value max = arith::MaximumFOp::create(bodyBuilder, loc, x, acc); + scf::ReduceReturnOp::create(bodyBuilder, loc, max); + }) .getResult(0); // weight = (weight - max(weight)) / math.sqrt(querySizes[-1]) - b.create( - loc, SmallVector(weightRank, zero), weightDynSizes, + scf::ParallelOp::create( + b, loc, SmallVector(weightRank, zero), weightDynSizes, SmallVector(weightRank, one), [&](OpBuilder &b, Location loc, ValueRange localIVs) { - Value x = b.create(loc, weight, localIVs); - x = b.create(loc, x, globalMax); - b.create(loc, x, weight, localIVs); + Value x = memref::LoadOp::create(b, loc, weight, localIVs); + x = arith::SubFOp::create(b, loc, x, globalMax); + memref::StoreOp::create(b, loc, x, weight, localIVs); }); // calculate exp(weight) SmallVector min(weightRank, zero), max(weightDynSizes.begin(), weightDynSizes.end()), steps(weightRank, one); - b.create( - loc, min, max, steps, + scf::ParallelOp::create( + b, loc, min, max, steps, [&](OpBuilder &b, Location loc, ValueRange localIVs) { - Value x = b.create(loc, weight, localIVs); - x = b.create(loc, x); - 
b.create(loc, x, weight, localIVs); + Value x = memref::LoadOp::create(b, loc, weight, localIVs); + x = math::ExpOp::create(b, loc, x); + memref::StoreOp::create(b, loc, x, weight, localIVs); }); llvm::SmallVector expWeightDynDims(weightFilteredDynSizes); if (weightSizes.back() == ShapedType::kDynamic) expWeightDynDims.resize(expWeightDynDims.size() - 1); - Value expWeightSum = b.create( - loc, + Value expWeightSum = memref::AllocOp::create( + b, loc, MemRefType::get( SmallVector(weightSizes.begin(), weightSizes.end() - 1), elementType), expWeightDynDims); - b.create( - loc, SmallVector(weightRank - 1, zero), + scf::ParallelOp::create( + b, loc, SmallVector(weightRank - 1, zero), SmallVector{weightDynSizes.begin(), weightDynSizes.end() - 1}, SmallVector(weightRank - 1, one), [&](OpBuilder &b, Location loc, ValueRange localIVs) { - b.create(loc, zeroF, expWeightSum, localIVs); + memref::StoreOp::create(b, loc, zeroF, expWeightSum, localIVs); }); // Loop over all dims but -1 - b.create( - loc, SmallVector(weightRank - 1, zero), + scf::ParallelOp::create( + b, loc, SmallVector(weightRank - 1, zero), SmallVector(weightDynSizes.begin(), weightDynSizes.end() - 1), SmallVector(weightRank - 1, one), [&](OpBuilder &b, Location loc, ValueRange outsideDims) { // Sum over last dim - b.create( - loc, zero, dim, one, + scf::ParallelOp::create( + b, loc, zero, dim, one, [&](OpBuilder &b, Location loc, ValueRange localIVs) { SmallVector coords(outsideDims); coords.push_back(localIVs[0]); Value x = - b.create(loc, expWeightSum, outsideDims); - Value y = b.create(loc, weight, coords); - Value sum = b.create(loc, x, y); - b.create(loc, sum, expWeightSum, outsideDims); + memref::LoadOp::create(b, loc, expWeightSum, outsideDims); + Value y = memref::LoadOp::create(b, loc, weight, coords); + Value sum = arith::AddFOp::create(b, loc, x, y); + memref::StoreOp::create(b, loc, sum, expWeightSum, outsideDims); }); }); // calculate exp(weight) / sum(exp(weight)) - b.create( - loc, 
SmallVector(weightRank, zero), + scf::ParallelOp::create( + b, loc, SmallVector(weightRank, zero), SmallVector(weightDynSizes.begin(), weightDynSizes.end()), SmallVector(weightRank, one), [&](OpBuilder &b, Location loc, ValueRange localIVs) { SmallVector sumIVs(localIVs); sumIVs.pop_back(); - Value x = b.create(loc, weight, localIVs); - Value sum = b.create(loc, expWeightSum, sumIVs); - Value divResult = b.create(loc, x, sum); + Value x = memref::LoadOp::create(b, loc, weight, localIVs); + Value sum = memref::LoadOp::create(b, loc, expWeightSum, sumIVs); + Value divResult = arith::DivFOp::create(b, loc, x, sum); // Set to 0 if sum is 0 (can occur during boolean mask / large negative // QK) - Value isSumZero = - b.create(loc, arith::CmpFPredicate::OEQ, sum, zeroF); + Value isSumZero = arith::CmpFOp::create( + b, loc, arith::CmpFPredicate::OEQ, sum, zeroF); Value result = - b.create(loc, isSumZero, zeroF, divResult); + arith::SelectOp::create(b, loc, isSumZero, zeroF, divResult); - b.create(loc, result, weight, localIVs); + memref::StoreOp::create(b, loc, result, weight, localIVs); }); // output = weight @ value @@ -463,8 +464,8 @@ SmallVector ScanOp::getIterationDomain(OpBuilder &builder) { int64_t operandRank = getOperandRank(); SmallVector loopBounds(operandRank); Location loc = getLoc(); - Value zero = builder.create(loc, 0); - Value one = builder.create(loc, 1); + Value zero = arith::ConstantIndexOp::create(builder, loc, 0); + Value one = arith::ConstantIndexOp::create(builder, loc, 1); Value source = input(); for (auto dim : llvm::seq(0, operandRank)) { loopBounds[dim].offset = zero; @@ -507,11 +508,11 @@ LogicalResult ScanOp::generateScalarImplementation(OpBuilder &b, Location loc, ValueRange ivs) { SmallVector indices, scanBlkArgs; indices.append(ivs.begin(), ivs.end()); - Value zero = b.create(loc, 0); - Value one = b.create(loc, 1); + Value zero = arith::ConstantIndexOp::create(b, loc, 0); + Value one = arith::ConstantIndexOp::create(b, loc, 1); uint64_t 
scanDim = getDimension(); - Value cond = b.create(loc, arith::CmpIPredicate::eq, - indices[scanDim], zero); + Value cond = arith::CmpIOp::create(b, loc, arith::CmpIPredicate::eq, + indices[scanDim], zero); bool isInclusive = getInclusive(); SmallVector accIndices; for (size_t i = 0; i < indices.size(); i++) { @@ -519,30 +520,32 @@ LogicalResult ScanOp::generateScalarImplementation(OpBuilder &b, Location loc, accIndices.push_back(indices[i]); } - auto scfIf = b.create( - loc, cond, + auto scfIf = scf::IfOp::create( + b, loc, cond, [&](OpBuilder &b, Location loc) { if (isInclusive) { - auto value = b.create(loc, input(), indices); - b.create(loc, value, output(), indices); + auto value = memref::LoadOp::create(b, loc, input(), indices); + memref::StoreOp::create(b, loc, value, output(), indices); } else { - auto value = b.create(loc, accumulator(), accIndices); - b.create(loc, value, output(), indices); + auto value = + memref::LoadOp::create(b, loc, accumulator(), accIndices); + memref::StoreOp::create(b, loc, value, output(), indices); } - b.create(loc); + scf::YieldOp::create(b, loc); }, [&](OpBuilder &b, Location loc) { SmallVector indices(ivs.begin(), ivs.end()); Value iv = indices[scanDim]; - Value ivMinusOne = b.create(loc, iv, one); + Value ivMinusOne = arith::SubIOp::create(b, loc, iv, one); indices[scanDim] = ivMinusOne; - scanBlkArgs.push_back(b.create(loc, output(), indices)); + scanBlkArgs.push_back( + memref::LoadOp::create(b, loc, output(), indices)); Value i0; if (!isInclusive) - i0 = b.create(loc, input(), indices); + i0 = memref::LoadOp::create(b, loc, input(), indices); indices[scanDim] = iv; if (isInclusive) - i0 = b.create(loc, input(), indices); + i0 = memref::LoadOp::create(b, loc, input(), indices); scanBlkArgs.push_back(i0); }); @@ -559,13 +562,13 @@ LogicalResult ScanOp::generateScalarImplementation(OpBuilder &b, Location loc, for (auto &blockOp : srcBlock.without_terminator()) { b.clone(blockOp, bvm); } - b.create( - loc, 
bvm.lookupOrDefault(srcBlock.getTerminator()->getOperand(0)), + memref::StoreOp::create( + b, loc, bvm.lookupOrDefault(srcBlock.getTerminator()->getOperand(0)), output(), indices); - b.create( - loc, bvm.lookupOrDefault(srcBlock.getTerminator()->getOperand(0)), + memref::StoreOp::create( + b, loc, bvm.lookupOrDefault(srcBlock.getTerminator()->getOperand(0)), accumulator(), accIndices); - b.create(loc); + scf::YieldOp::create(b, loc); } return success(); } @@ -767,8 +770,8 @@ bool ScatterOp::payloadUsesValueFromOperand(OpOperand *opOperand) { SmallVector ScatterOp::getIterationDomain(OpBuilder &builder) { Location loc = getLoc(); - Value zero = builder.create(loc, 0); - Value one = builder.create(loc, 1); + Value zero = arith::ConstantIndexOp::create(builder, loc, 0); + Value one = arith::ConstantIndexOp::create(builder, loc, 1); SmallVector ranges; for (auto dim : llvm::seq(0, getUpdateType().getRank())) { Value ub = getDimValue(builder, loc, updates(), dim); @@ -781,7 +784,7 @@ LogicalResult ScatterOp::generateScalarImplementation(OpBuilder &b, Location loc, ValueRange ivs) { auto indexDepth = getIndexDepth(); - Value update = b.create(loc, updates(), ivs); + Value update = memref::LoadOp::create(b, loc, updates(), ivs); SmallVector starts; SmallVector loadIndices; loadIndices.push_back(ivs.front()); @@ -799,17 +802,17 @@ LogicalResult ScatterOp::generateScalarImplementation(OpBuilder &b, ArrayRef dimMap = getDimensionMap(); for (auto i : llvm::seq(0, indexDepth)) { - loadIndices.back() = b.create(loc, i); - Value idx = b.create(loc, indices(), loadIndices); - Value ret = b.create(loc, b.getIndexType(), idx); + loadIndices.back() = arith::ConstantIndexOp::create(b, loc, i); + Value idx = memref::LoadOp::create(b, loc, indices(), loadIndices); + Value ret = arith::IndexCastOp::create(b, loc, b.getIndexType(), idx); auto dim = dimMap[i]; if (starts[dim]) - ret = b.create(loc, ret, starts[dim]); + ret = arith::AddIOp::create(b, loc, ret, starts[dim]); starts[dim] = 
ret; } - Value init = b.create(loc, original(), starts); + Value init = memref::LoadOp::create(b, loc, original(), starts); IRMapping bvm; Block &block = getRegion().front(); @@ -820,8 +823,8 @@ LogicalResult ScatterOp::generateScalarImplementation(OpBuilder &b, } // The last op is linalg_ext.yield op. Store the operand to // destination. - b.create( - loc, bvm.lookupOrDefault(block.getTerminator()->getOperand(0)), + memref::StoreOp::create( + b, loc, bvm.lookupOrDefault(block.getTerminator()->getOperand(0)), original(), starts); return success(); } @@ -899,8 +902,8 @@ SmallVector SortOp::getIterationDomain(OpBuilder &builder) { int64_t operandRank = getOperandRank(); SmallVector loopBounds(operandRank); Location loc = getLoc(); - Value zero = builder.create(loc, 0); - Value one = builder.create(loc, 1); + Value zero = arith::ConstantIndexOp::create(builder, loc, 0); + Value one = arith::ConstantIndexOp::create(builder, loc, 1); Value source = operand(0); for (auto dim : llvm::seq(0, operandRank)) { loopBounds[dim].offset = zero; @@ -916,28 +919,28 @@ LogicalResult SortOp::generateScalarImplementation(OpBuilder &b, Location loc, SmallVector indices, sortBlkArgs; indices.append(ivs.begin(), ivs.end()); // Bubble sort innermost loop. 
- Value zero = b.create(loc, 0); - Value one = b.create(loc, 1); + Value zero = arith::ConstantIndexOp::create(b, loc, 0); + Value one = arith::ConstantIndexOp::create(b, loc, 1); Value ub; if (getOperandType(0).isDynamicDim(sortDim)) { - ub = b.create(loc, operand(0), sortDim); + ub = memref::DimOp::create(b, loc, operand(0), sortDim); } else { - ub = b.create( - loc, getOperandType(0).getDimSize(sortDim)); + ub = arith::ConstantIndexOp::create(b, loc, + getOperandType(0).getDimSize(sortDim)); } - ub = b.create(loc, ub, one); - auto scfFor = b.create( - loc, zero, ub, one, ValueRange{}, + ub = arith::SubIOp::create(b, loc, ub, one); + auto scfFor = scf::ForOp::create( + b, loc, zero, ub, one, ValueRange{}, [&](OpBuilder &b, Location loc, Value iv, ValueRange iters) { SmallVector indices(ivs); - Value ivPlusOne = b.create(loc, iv, one); + Value ivPlusOne = arith::AddIOp::create(b, loc, iv, one); for (auto output : getOutputOperands()) { indices[sortDim] = iv; sortBlkArgs.push_back( - b.create(loc, output->get(), indices)); + memref::LoadOp::create(b, loc, output->get(), indices)); indices[sortDim] = ivPlusOne; sortBlkArgs.push_back( - b.create(loc, output->get(), indices)); + memref::LoadOp::create(b, loc, output->get(), indices)); } }); @@ -959,30 +962,30 @@ LogicalResult SortOp::generateScalarImplementation(OpBuilder &b, Location loc, OpBuilder::InsertionGuard g(b); b.setInsertionPointToEnd(&region.front()); - b.create( - loc, cond, + scf::IfOp::create( + b, loc, cond, [&](OpBuilder &b, Location loc) { // Do not swap the pairs if true. - b.create(loc); + scf::YieldOp::create(b, loc); }, [&](OpBuilder &b, Location loc) { // Swap the pairs if false.
SmallVector indices(ivs.begin(), ivs.end()); Value ivPlusOne = - b.create(loc, scfFor.getInductionVar(), one); + arith::AddIOp::create(b, loc, scfFor.getInductionVar(), one); for (int i = 0, e = getNumOutputs(); i < e; ++i) { Value v1 = sortBlkArgs[i * 2]; Value v2 = sortBlkArgs[i * 2 + 1]; indices[sortDim] = scfFor.getInductionVar(); - b.create(loc, v2, getOutputOperand(i)->get(), - indices); + memref::StoreOp::create(b, loc, v2, getOutputOperand(i)->get(), + indices); indices[sortDim] = ivPlusOne; - b.create(loc, v1, getOutputOperand(i)->get(), - indices); + memref::StoreOp::create(b, loc, v1, getOutputOperand(i)->get(), + indices); } - b.create(loc); + scf::YieldOp::create(b, loc); }); - b.create(loc); + scf::YieldOp::create(b, loc); return success(); } @@ -1086,8 +1089,8 @@ SmallVector TopkOp::getIterationDomain(OpBuilder &builder) { int64_t operandRank = getInputRank(); SmallVector loopBounds(operandRank); Location loc = getLoc(); - Value zero = builder.create(loc, 0); - Value one = builder.create(loc, 1); + Value zero = arith::ConstantIndexOp::create(builder, loc, 0); + Value one = arith::ConstantIndexOp::create(builder, loc, 1); Value source = values(); for (auto dim : llvm::enumerate(getInputType().getShape())) { loopBounds[dim.index()].offset = zero; @@ -1101,23 +1104,23 @@ SmallVector TopkOp::getIterationDomain(OpBuilder &builder) { LogicalResult TopkOp::generateScalarImplementation(OpBuilder &b, Location loc, ValueRange ivs) { uint64_t kDim = getDimension(); - Value zero = b.create(loc, 0); - Value one = b.create(loc, 1); - Value initialValue = b.create(loc, values(), ivs); + Value zero = arith::ConstantIndexOp::create(b, loc, 0); + Value one = arith::ConstantIndexOp::create(b, loc, 1); + Value initialValue = memref::LoadOp::create(b, loc, values(), ivs); // If the indices tensor is not provided, the value index is derived from the // loop induction variables. 
Value initialIndex; if (indices()) { - initialIndex = b.create(loc, *indices(), ivs); + initialIndex = memref::LoadOp::create(b, loc, *indices(), ivs); } else { Value rawInitialIndex = ivs[kDim]; initialIndex = - b.create(loc, b.getI32Type(), rawInitialIndex); + arith::IndexCastOp::create(b, loc, b.getI32Type(), rawInitialIndex); } // Compute K (ub) from the selected dim of the output - Value ub = b.create(loc, outputValues(), getDimension()); + Value ub = memref::DimOp::create(b, loc, outputValues(), getDimension()); // Inner K loop functions: // Load current K value and index @@ -1128,13 +1131,13 @@ LogicalResult TopkOp::generateScalarImplementation(OpBuilder &b, Location loc, // Store new k value and index // Yield loop carry values after K selection Value kValue, kIndex; - auto scfFor = b.create( - loc, zero, ub, one, ValueRange{initialValue, initialIndex}, + auto scfFor = scf::ForOp::create( + b, loc, zero, ub, one, ValueRange{initialValue, initialIndex}, [&](OpBuilder &b, Location loc, Value iv, ValueRange loopCarryValues) { SmallVector indices(ivs); indices[kDim] = iv; - kValue = b.create(loc, outputValues(), indices); - kIndex = b.create(loc, outputIndices(), indices); + kValue = memref::LoadOp::create(b, loc, outputValues(), indices); + kIndex = memref::LoadOp::create(b, loc, outputIndices(), indices); }); SmallVector indices(ivs); @@ -1168,28 +1171,29 @@ LogicalResult TopkOp::generateScalarImplementation(OpBuilder &b, Location loc, // f(x,y) --> forwardCmpRes // f(y,x) --> reverseCmpRes // if forwardCmpRes == reverseCmpRes then select which came first - Value cmpValuesEqual = b.create( - loc, arith::CmpIPredicate::eq, forwardCmpRes, reverseCmpRes); - Value cmpFirstIndex = b.create( - loc, arith::CmpIPredicate::slt, loopCarryValues[1], kIndex); + Value cmpValuesEqual = arith::CmpIOp::create( + b, loc, arith::CmpIPredicate::eq, forwardCmpRes, reverseCmpRes); + Value cmpFirstIndex = arith::CmpIOp::create( + b, loc, arith::CmpIPredicate::slt, 
loopCarryValues[1], kIndex); Value combinedCmpEqRes = - b.create(loc, cmpValuesEqual, cmpFirstIndex); + arith::AndIOp::create(b, loc, cmpValuesEqual, cmpFirstIndex); // True if N > K or N came before K Value indexCmpRes = - b.create(loc, forwardCmpRes, combinedCmpEqRes); + arith::OrIOp::create(b, loc, forwardCmpRes, combinedCmpEqRes); // Select results for K based on comparisons - Value resultKValue = b.create(loc, forwardCmpRes, - loopCarryValues[0], kValue); - Value resultKIndex = - b.create(loc, indexCmpRes, loopCarryValues[1], kIndex); - b.create(loc, resultKValue, outputValues(), indices); - b.create(loc, resultKIndex, outputIndices(), indices); + Value resultKValue = arith::SelectOp::create(b, loc, forwardCmpRes, + loopCarryValues[0], kValue); + Value resultKIndex = arith::SelectOp::create(b, loc, indexCmpRes, + loopCarryValues[1], kIndex); + memref::StoreOp::create(b, loc, resultKValue, outputValues(), indices); + memref::StoreOp::create(b, loc, resultKIndex, outputIndices(), indices); // Select loop carry, opposite of K results - Value resultCarryValue = b.create( - loc, forwardCmpRes, kValue, loopCarryValues[0]); - Value resultCarryIndex = - b.create(loc, indexCmpRes, kIndex, loopCarryValues[1]); - b.create(loc, ValueRange{resultCarryValue, resultCarryIndex}); + Value resultCarryValue = arith::SelectOp::create( + b, loc, forwardCmpRes, kValue, loopCarryValues[0]); + Value resultCarryIndex = arith::SelectOp::create( + b, loc, indexCmpRes, kIndex, loopCarryValues[1]); + scf::YieldOp::create(b, loc, + ValueRange{resultCarryValue, resultCarryIndex}); } return success(); } @@ -1262,8 +1266,8 @@ struct FoldTensorCastOp : public OpInterfaceRewritePattern { Value oldResult = std::get<0>(result); Value newResult = std::get<1>(result); if (newResult.getType() != oldResult.getType()) { - replacements.push_back(rewriter.create( - op->getLoc(), oldResult.getType(), newResult)); + replacements.push_back(tensor::CastOp::create( + rewriter, op->getLoc(), 
oldResult.getType(), newResult)); } else { replacements.push_back(newResult); } diff --git a/lib/Dialect/TMTensor/Transforms/Bufferize.cpp b/lib/Dialect/TMTensor/Transforms/Bufferize.cpp index d82919ef8f13..ca47cdd6033a 100644 --- a/lib/Dialect/TMTensor/Transforms/Bufferize.cpp +++ b/lib/Dialect/TMTensor/Transforms/Bufferize.cpp @@ -31,9 +31,10 @@ using namespace ::mlir::torch::TMTensor; static Value cloneMemref(Location loc, Value memref, OpBuilder &b) { auto memrefType = cast(memref.getType()); - auto alloc = b.create( - loc, memref::getMixedSizes(b, loc, memref), memrefType.getElementType()); - b.create(loc, memref, alloc); + auto alloc = + memref::AllocOp::create(b, loc, memref::getMixedSizes(b, loc, memref), + memrefType.getElementType()); + memref::CopyOp::create(b, loc, memref, alloc); return alloc; } @@ -69,12 +70,12 @@ allocateBuffersForResults(Location loc, TMTensorOp tmtensorOp, // Allocate buffers for statically-shaped results. if (memrefType.hasStaticShape()) { - resultBuffers.push_back(b.create(loc, memrefType)); + resultBuffers.push_back(memref::AllocOp::create(b, loc, memrefType)); continue; } - resultBuffers.push_back(b.create( - loc, memref::getMixedSizes(b, loc, resultTensor), + resultBuffers.push_back(memref::AllocOp::create( + b, loc, memref::getMixedSizes(b, loc, resultTensor), memrefType.getElementType())); } return success(); @@ -127,7 +128,7 @@ static Value materializeToTensor(OpBuilder &builder, TensorType type, ValueRange inputs, Location loc) { assert(inputs.size() == 1); assert(isa(inputs[0].getType())); - return builder.create(loc, type, inputs[0]); + return bufferization::ToTensorOp::create(builder, loc, type, inputs[0]); } /// Converts TMTensor operations that work on tensor-type operands or results to @@ -178,7 +179,7 @@ struct TMTensorBufferizePass } if (isa(inputs[0].getType())) { // Tensor to MemRef cast. 
- return builder.create(loc, type, inputs[0]); + return bufferization::ToBufferOp::create(builder, loc, type, inputs[0]); } llvm_unreachable("only tensor/memref input types supported"); }); diff --git a/lib/Dialect/TMTensor/Transforms/ConvertToLoops.cpp b/lib/Dialect/TMTensor/Transforms/ConvertToLoops.cpp index 3e31bfc46394..74d539ab6d8a 100644 --- a/lib/Dialect/TMTensor/Transforms/ConvertToLoops.cpp +++ b/lib/Dialect/TMTensor/Transforms/ConvertToLoops.cpp @@ -47,13 +47,13 @@ static LogicalResult lowerToLoopsImpl(OpBuilder &builder, getValueOrCreateConstantIndexOp(builder, loc, loopRanges[loopDepth].size); Value stride = getValueOrCreateConstantIndexOp(builder, loc, loopRanges[loopDepth].stride); - builder.create( - loc, offset, size, stride, ValueRange{}, + scf::ForOp::create( + builder, loc, offset, size, stride, ValueRange{}, [&](OpBuilder &b, Location loc, Value iv, ValueRange args) { ivs.push_back(iv); status = lowerToLoopsImpl(b, scalarLoopOp, loopRanges, loopDepth + 1, ivs); - b.create(loc); + scf::YieldOp::create(b, loc); }); return status; } diff --git a/lib/Dialect/Torch/IR/TorchDialect.cpp b/lib/Dialect/Torch/IR/TorchDialect.cpp index fdd9875229e8..001299f21530 100644 --- a/lib/Dialect/Torch/IR/TorchDialect.cpp +++ b/lib/Dialect/Torch/IR/TorchDialect.cpp @@ -147,34 +147,35 @@ Operation *TorchDialect::materializeConstant(OpBuilder &builder, Attribute value, Type type, Location loc) { if (auto integerType = dyn_cast(type)) - return builder.create(loc, cast(value)); + return Torch::ConstantIntOp::create(builder, loc, cast(value)); if (auto floatType = dyn_cast(type)) - return builder.create(loc, cast(value)); + return Torch::ConstantFloatOp::create(builder, loc, cast(value)); if (auto numberType = dyn_cast(type)) { if (auto floatValue = dyn_cast(value)) { - return builder.create(loc, floatValue); + return Torch::ConstantNumberOp::create(builder, loc, floatValue); } else if (auto intValue = dyn_cast(value)) { - return builder.create(loc, intValue); + return 
Torch::ConstantNumberOp::create(builder, loc, intValue); } } if (isa(type)) { - return builder.create(loc, cast(value)); + return Torch::ConstantBoolOp::create(builder, loc, + cast(value)); } if (isa(type)) - return builder.create(loc); + return ConstantNoneOp::create(builder, loc); if (auto stringAttr = dyn_cast(value)) - return builder.create(loc, stringAttr); + return ConstantStrOp::create(builder, loc, stringAttr); if (auto elementsAttr = dyn_cast(value)) { // Only !torch.vtensor can be constant folded. !torch.tensor has // non-trivial aliasing semantics which prevent deduplicating it. assert(isa(type) && "should be a vtensor type!"); - return builder.create(loc, elementsAttr); + return ValueTensorLiteralOp::create(builder, loc, elementsAttr); } return nullptr; diff --git a/lib/Dialect/Torch/IR/TorchOps.cpp b/lib/Dialect/Torch/IR/TorchOps.cpp index 7c1767b723bf..b58c6f211b03 100644 --- a/lib/Dialect/Torch/IR/TorchOps.cpp +++ b/lib/Dialect/Torch/IR/TorchOps.cpp @@ -63,8 +63,8 @@ Value mlir::torch::Torch::adjustStaticInformation(OpBuilder &builder, // If the type is a tensor, then adjust the static information. if ((isa(type) && isa(desiredType)) || (isa(type) && isa(desiredType))) { - Value adjusted = builder.create(value.getLoc(), - desiredType, value); + Value adjusted = TensorStaticInfoCastOp::create(builder, value.getLoc(), + desiredType, value); return adjusted; } @@ -73,7 +73,7 @@ Value mlir::torch::Torch::adjustStaticInformation(OpBuilder &builder, if (isValidSubtype(type, desiredType)) { if (!userAllowsRefinement) { Value adjusted = - builder.create(value.getLoc(), desiredType, value); + DerefineOp::create(builder, value.getLoc(), desiredType, value); return adjusted; } else { return value; @@ -83,8 +83,8 @@ Value mlir::torch::Torch::adjustStaticInformation(OpBuilder &builder, // If the desiredType is subtype of type, then we assume that the desiredType // is dynamically valid, so we do an unchecked cast. 
if (isValidSubtype(desiredType, type)) { - Value adjusted = - builder.create(value.getLoc(), desiredType, value); + Value adjusted = PrimUncheckedCastOp::create(builder, value.getLoc(), + desiredType, value); return adjusted; } @@ -99,8 +99,8 @@ Value mlir::torch::Torch::copyTensorToType(OpBuilder &builder, Location loc, // Adjust the static information in the type to match between the original and // new types. if (!originalType.hasSameSizesAndDtype(newType)) { - tensor = builder.create( - loc, originalType.getWithSizesAndDtypeFrom(newType), tensor); + tensor = TensorStaticInfoCastOp::create( + builder, loc, originalType.getWithSizesAndDtypeFrom(newType), tensor); } // Unless both the original and new types are both value tensors, we end @@ -108,9 +108,9 @@ Value mlir::torch::Torch::copyTensorToType(OpBuilder &builder, Location loc, // domains. If both the original and new types are both non-value tensors, // then we do the copy by going to a value tensor and back. if (isa(tensor.getType())) - tensor = builder.create(loc, tensor); + tensor = CopyToValueTensorOp::create(builder, loc, tensor); if (isa(newType)) - tensor = builder.create(loc, tensor); + tensor = CopyToNonValueTensorOp::create(builder, loc, tensor); return tensor; } @@ -182,13 +182,13 @@ static Value getScalarIntValue(Value input, Location loc, if (inputDtype.isInteger(64)) { auto val = cast(valueTensorLiteralOp.getValue()) .getSplatValue(); - return rewriter.create( - loc, rewriter.getI64IntegerAttr(val)); + return Torch::ConstantIntOp::create(rewriter, loc, + rewriter.getI64IntegerAttr(val)); } else { auto val = cast(valueTensorLiteralOp.getValue()) .getSplatValue(); - return rewriter.create( - loc, rewriter.getI64IntegerAttr(val)); + return Torch::ConstantIntOp::create(rewriter, loc, + rewriter.getI64IntegerAttr(val)); } } else if (auto primNumToTensorScalarOp = input.getDefiningOp()) { @@ -223,8 +223,8 @@ static Value getScalarFloatValue(Value input, Location loc, auto val = 
cast(valueTensorLiteralOp.getValue()) .getSplatValue() .getValueAsDouble(); - return rewriter.create( - loc, rewriter.getF64FloatAttr(val)); + return Torch::ConstantFloatOp::create(rewriter, loc, + rewriter.getF64FloatAttr(val)); } else if (auto primNumToTensorScalarOp = input.getDefiningOp()) { return primNumToTensorScalarOp.getA(); @@ -553,8 +553,8 @@ void PrimIfOp::getCanonicalizationPatterns(RewritePatternSet &patterns, continue; newResultTypes.push_back(op->getResult(i).getType()); } - auto newIf = rewriter.create(op->getLoc(), newResultTypes, - op.getCondition()); + auto newIf = PrimIfOp::create(rewriter, op->getLoc(), newResultTypes, + op.getCondition()); rewriter.inlineRegionBefore(op.getThenRegion(), newIf.getThenRegion(), newIf.getThenRegion().end()); rewriter.inlineRegionBefore(op.getElseRegion(), newIf.getElseRegion(), @@ -1076,14 +1076,14 @@ void AtenToDtypeLayoutOp::getCanonicalizationPatterns( if (isa(op.getDevice().getType())) { // The device arg is `none`. Rewrite to to.dtype. - AtenToDtypeOp toDtype = rewriter.create( - op.getLoc(), op.getType(), op.getSelf(), op.getDtype(), + AtenToDtypeOp toDtype = AtenToDtypeOp::create( + rewriter, op.getLoc(), op.getType(), op.getSelf(), op.getDtype(), op.getNonBlocking(), op.getCopy(), op.getMemoryFormat()); rewriter.replaceOp(op, toDtype->getResults()); } else { // The device arg is not `none`. Rewrite to to.device. 
- AtenToDeviceOp toDevice = rewriter.create( - op.getLoc(), op.getType(), op.getSelf(), op.getDevice(), + AtenToDeviceOp toDevice = AtenToDeviceOp::create( + rewriter, op.getLoc(), op.getType(), op.getSelf(), op.getDevice(), op.getDtype(), op.getNonBlocking(), op.getCopy(), op.getMemoryFormat()); rewriter.replaceOp(op, toDevice->getResults()); @@ -1103,8 +1103,8 @@ void AtenToOtherOp::getCanonicalizationPatterns(RewritePatternSet &patterns, patterns.add(+[](AtenToOtherOp op, PatternRewriter &rewriter) { auto lhs = op.getSelf(); auto rhs = op.getOther(); - auto getRhsDevice = rewriter.create(op.getLoc(), rhs); - auto getRhsDtype = rewriter.create(op.getLoc(), rhs); + auto getRhsDevice = PrimDeviceOp::create(rewriter, op.getLoc(), rhs); + auto getRhsDtype = PrimDtypeOp::create(rewriter, op.getLoc(), rhs); rewriter.replaceOpWithNewOp( op, op.getType(), lhs, getRhsDevice.getResult(), getRhsDtype.getResult(), op.getNonBlocking(), op.getCopy(), @@ -1123,10 +1123,10 @@ void Aten_CastFloatOp::getCanonicalizationPatterns(RewritePatternSet &patterns, patterns.add(+[](Aten_CastFloatOp op, PatternRewriter &rewriter) { auto self = op.getSelf(); auto loc = op.getLoc(); - Value constNone = rewriter.create(loc); - Value f32Type = rewriter.create( - loc, (int)torch_upstream::ScalarType::Float); - Value constFalse = rewriter.create(loc, false); + Value constNone = ConstantNoneOp::create(rewriter, loc); + Value f32Type = ConstantIntOp::create( + rewriter, loc, (int)torch_upstream::ScalarType::Float); + Value constFalse = ConstantBoolOp::create(rewriter, loc, false); rewriter.replaceOpWithNewOp(op, op.getType(), self, f32Type, op.getNonBlocking(), constFalse, constNone); @@ -1144,10 +1144,10 @@ void Aten_CastLongOp::getCanonicalizationPatterns(RewritePatternSet &patterns, patterns.add(+[](Aten_CastLongOp op, PatternRewriter &rewriter) { auto self = op.getSelf(); auto loc = op.getLoc(); - Value constNone = rewriter.create(loc); - Value longType = rewriter.create( - loc, 
(int)torch_upstream::ScalarType::Long); - Value constFalse = rewriter.create(loc, false); + Value constNone = ConstantNoneOp::create(rewriter, loc); + Value longType = ConstantIntOp::create( + rewriter, loc, (int)torch_upstream::ScalarType::Long); + Value constFalse = ConstantBoolOp::create(rewriter, loc, false); rewriter.replaceOpWithNewOp(op, op.getType(), self, longType, op.getNonBlocking(), constFalse, constNone); @@ -1333,15 +1333,15 @@ LogicalResult rewrite0DBinaryTensorOp(Operation *op, "only int scalar alpha is supported"); } if (isa(op)) - lhs = rewriter.create(loc, lhs, alpha); + lhs = AtenMulIntOp::create(rewriter, loc, lhs, alpha); else - rhs = rewriter.create(loc, rhs, alpha); + rhs = AtenMulIntOp::create(rewriter, loc, rhs, alpha); } if (isa(op)) { if (isa(op->getOperand(2).getType())) { // None rounding mode - Value quotient = rewriter.create(loc, lhs, rhs); + Value quotient = AtenDivOp::create(rewriter, loc, lhs, rhs); rewriter.replaceOpWithNewOp(op, outType, quotient); return success(); @@ -1352,7 +1352,7 @@ LogicalResult rewrite0DBinaryTensorOp(Operation *op, op, "only None, 'floor' or 'trunc' rounding mode is supported"); } if (roundingMode == "floor") { - Value quotient = rewriter.create(loc, lhs, rhs); + Value quotient = AtenFloordivIntOp::create(rewriter, loc, lhs, rhs); rewriter.replaceOpWithNewOp(op, outType, quotient); return success(); @@ -1372,8 +1372,8 @@ LogicalResult rewrite0DBinaryTensorOp(Operation *op, } int64_t result = (int64_t)std::trunc((double)lhsInt / rhsInt); - Value resultScalar = rewriter.create( - loc, rewriter.getI64IntegerAttr(result)); + Value resultScalar = ConstantIntOp::create( + rewriter, loc, rewriter.getI64IntegerAttr(result)); rewriter.replaceOpWithNewOp(op, outType, resultScalar); return success(); @@ -1385,13 +1385,13 @@ LogicalResult rewrite0DBinaryTensorOp(Operation *op, Value result; // Other Add/Sub/Mul ops if (isa(op)) { - result = rewriter.create(loc, lhs, rhs); + result = AtenAddIntOp::create(rewriter, 
loc, lhs, rhs); } else if (isa(op)) { - result = rewriter.create(loc, lhs, rhs); + result = AtenSubIntOp::create(rewriter, loc, lhs, rhs); } else if (isa(op)) { - result = rewriter.create(loc, rhs, lhs); + result = AtenSubIntOp::create(rewriter, loc, rhs, lhs); } else if (isa(op)) { - result = rewriter.create(loc, lhs, rhs); + result = AtenMulIntOp::create(rewriter, loc, lhs, rhs); } rewriter.replaceOpWithNewOp(op, outType, result); return success(); @@ -2386,8 +2386,8 @@ void AtenSizeOp::getCanonicalizationPatterns(RewritePatternSet &patterns, return rewriter.notifyMatchFailure(op, "all sizes not known"); SmallVector listElements; for (int64_t size : type->getSizes()) { - listElements.push_back(rewriter.create( - op->getLoc(), rewriter.getI64IntegerAttr(size))); + listElements.push_back(Torch::ConstantIntOp::create( + rewriter, op->getLoc(), rewriter.getI64IntegerAttr(size))); } rewriter.replaceOpWithNewOp( op, Torch::ListType::get(rewriter.getType()), @@ -2446,7 +2446,7 @@ void AtenUnflattenIntOp::getCanonicalizationPatterns( bool dimWasConstant = matchPattern(unflattenDim, m_TorchConstantInt(&dimAsInt)); Value self = op.getSelf(); - Value cstMOne = rewriter.create(op.getLoc(), -1); + Value cstMOne = Torch::ConstantIntOp::create(rewriter, op.getLoc(), -1); // the runtime asserts below are introduced to catch malformed unflatten ops // possibly generated from onnx IR. 
Value unsqueeze; @@ -2460,24 +2460,24 @@ void AtenUnflattenIntOp::getCanonicalizationPatterns( // check if the remaining size value is either -1 or equal to original // size at dim Value selfSizeAtDim = - rewriter.create(op.getLoc(), self, unflattenDim); - Value isSameSize = rewriter.create( - op.getLoc(), selfSizeAtDim, sizeValues[1]); + AtenSizeIntOp::create(rewriter, op.getLoc(), self, unflattenDim); + Value isSameSize = AtenEqIntOp::create(rewriter, op.getLoc(), + selfSizeAtDim, sizeValues[1]); Value isMinusOne = - rewriter.create(op.getLoc(), cstMOne, sizeValues[1]); - Value isMOneOrSameSize = rewriter.create( - op.getLoc(), isMinusOne, isSameSize); - rewriter.create( - op.getLoc(), isMOneOrSameSize, + AtenEqIntOp::create(rewriter, op.getLoc(), cstMOne, sizeValues[1]); + Value isMOneOrSameSize = Aten__Or__BoolOp::create(rewriter, op.getLoc(), + isMinusOne, isSameSize); + Torch::RuntimeAssertOp::create( + rewriter, op.getLoc(), isMOneOrSameSize, rewriter.getStringAttr("unflatten sizes must be compatible")); } if (dim1 == 1) { // unsqueeze at dim + 1 Value dimPlusOne; if (!dimWasConstant) { - Value cstOne = rewriter.create(op.getLoc(), 1); + Value cstOne = Torch::ConstantIntOp::create(rewriter, op.getLoc(), 1); dimPlusOne = - rewriter.create(op.getLoc(), unflattenDim, cstOne); + AtenAddIntOp::create(rewriter, op.getLoc(), unflattenDim, cstOne); } else { // If dim was constant, creating an AtenAddIntOp will make // Torch::unsqueezeTensor() interpret it as still not being a constant, @@ -2486,8 +2486,8 @@ void AtenUnflattenIntOp::getCanonicalizationPatterns( // failure, when AtenUnsqueezeOp is in a later pass converted to // ExpandShapeOp, which is bound to fail shape inference in MLIR if // output dims are dynamic. 
- dimPlusOne = rewriter.create( - op.getLoc(), rewriter.getI64IntegerAttr(dimAsInt + 1)); + dimPlusOne = Torch::ConstantIntOp::create( + rewriter, op.getLoc(), rewriter.getI64IntegerAttr(dimAsInt + 1)); } FailureOr maybeUnsqueeze = Torch::unsqueezeTensor(rewriter, op, self, dimPlusOne); @@ -2497,15 +2497,15 @@ void AtenUnflattenIntOp::getCanonicalizationPatterns( // check if the remaining size value is either -1 or equal to original // size at dim Value selfSizeAtDim = - rewriter.create(op.getLoc(), self, unflattenDim); - Value isSameSize = rewriter.create( - op.getLoc(), selfSizeAtDim, sizeValues[0]); + AtenSizeIntOp::create(rewriter, op.getLoc(), self, unflattenDim); + Value isSameSize = AtenEqIntOp::create(rewriter, op.getLoc(), + selfSizeAtDim, sizeValues[0]); Value isMinusOne = - rewriter.create(op.getLoc(), cstMOne, sizeValues[0]); - Value isMOneOrSameSize = rewriter.create( - op.getLoc(), isMinusOne, isSameSize); - rewriter.create( - op.getLoc(), isMOneOrSameSize, + AtenEqIntOp::create(rewriter, op.getLoc(), cstMOne, sizeValues[0]); + Value isMOneOrSameSize = Aten__Or__BoolOp::create(rewriter, op.getLoc(), + isMinusOne, isSameSize); + Torch::RuntimeAssertOp::create( + rewriter, op.getLoc(), isMOneOrSameSize, rewriter.getStringAttr("unflatten sizes must be compatible")); } rewriter.replaceOpWithNewOp(op, op.getType(), @@ -2546,7 +2546,7 @@ void AtenPowTensorScalarOp::getCanonicalizationPatterns( if (expValue != static_cast(truncValue)) return failure(); Value IRScalar = - rewriter.create(op.getLoc(), truncValue); + Torch::ConstantIntOp::create(rewriter, op.getLoc(), truncValue); op->setOperand(1, IRScalar); return success(); }); @@ -2573,12 +2573,12 @@ void AtenPowTensorTensorOp::getCanonicalizationPatterns( return failure(); Value IRScalar; if (intAttr) - IRScalar = rewriter.create( - op.getLoc(), getIntAttrAsSigned(intAttr)); + IRScalar = Torch::ConstantIntOp::create(rewriter, op.getLoc(), + getIntAttrAsSigned(intAttr)); if (floatAttr) { double expValue = 
floatAttr.getValueAsDouble(); - IRScalar = rewriter.create(op.getLoc(), - APFloat(expValue)); + IRScalar = Torch::ConstantFloatOp::create(rewriter, op.getLoc(), + APFloat(expValue)); } rewriter.replaceOpWithNewOp(op, op.getType(), op.getSelf(), IRScalar); @@ -3053,10 +3053,11 @@ void AtenSortIntOp::getCanonicalizationPatterns(RewritePatternSet &patterns, SmallVector sortedListElements; for (int64_t elem : listElements) - sortedListElements.push_back(rewriter.create( - op->getLoc(), rewriter.getI64IntegerAttr(elem))); - Value result = rewriter.create( - op->getLoc(), Torch::ListType::get(rewriter.getType()), + sortedListElements.push_back(Torch::ConstantIntOp::create( + rewriter, op->getLoc(), rewriter.getI64IntegerAttr(elem))); + Value result = Torch::PrimListConstructOp::create( + rewriter, op->getLoc(), + Torch::ListType::get(rewriter.getType()), sortedListElements); op.getSelf().replaceAllUsesWith(result); @@ -3387,9 +3388,9 @@ void Torch::ConstantNumberOp::getCanonicalizationPatterns( Value constValue; Attribute value = op.getValueAttr(); if (auto floatValue = dyn_cast(value)) { - constValue = rewriter.create(loc, floatValue); + constValue = Torch::ConstantFloatOp::create(rewriter, loc, floatValue); } else if (auto intValue = dyn_cast(value)) { - constValue = rewriter.create(loc, intValue); + constValue = Torch::ConstantIntOp::create(rewriter, loc, intValue); } else { return failure(); } @@ -3479,8 +3480,8 @@ void Aten__Getitem__TOp::getCanonicalizationPatterns( void AtenMeshgridOp::getCanonicalizationPatterns(RewritePatternSet &patterns, MLIRContext *context) { patterns.add(+[](AtenMeshgridOp op, PatternRewriter &rewriter) { - Value constIndexing = rewriter.create( - op->getLoc(), rewriter.getStringAttr("ij")); + Value constIndexing = Torch::ConstantStrOp::create( + rewriter, op->getLoc(), rewriter.getStringAttr("ij")); rewriter.replaceOpWithNewOp( op, op->getResultTypes(), op.getTensors(), constIndexing); return success(); @@ -3666,8 +3667,8 @@ void 
PrimTupleIndexOp::getCanonicalizationPatterns(RewritePatternSet &patterns, Value replacement = tupleConstruct.getElements()[i]; if (replacement.getType() != op.getType()) { if (isa(op.getType())) { - replacement = rewriter.create( - op.getLoc(), op.getType(), replacement); + replacement = Torch::TensorStaticInfoCastOp::create( + rewriter, op.getLoc(), op.getType(), replacement); } else { return failure(); } @@ -3740,8 +3741,8 @@ void PrimListUnpackOp::getCanonicalizationPatterns(RewritePatternSet &patterns, for (int i = 0, s = op->getNumResults(); i < s; ++i) { auto element = listConstruct.getElements()[i]; if (element.getType() != op->getResult(i).getType()) { - element = rewriter.create( - op.getLoc(), op->getResult(i).getType(), element); + element = TensorStaticInfoCastOp::create( + rewriter, op.getLoc(), op->getResult(i).getType(), element); } unpacked.push_back(element); @@ -3871,9 +3872,10 @@ atenBinaryFloatOperatorFoldHelper(ArrayRef operands, void AtenStftOp::getCanonicalizationPatterns(RewritePatternSet &patterns, MLIRContext *context) { patterns.add(+[](AtenStftOp op, PatternRewriter &rewriter) { - Value falseVal = rewriter.create(op.getLoc(), false); + Value falseVal = + Torch::ConstantBoolOp::create(rewriter, op.getLoc(), false); Value padMode = - rewriter.create(op.getLoc(), "reflect"); + Torch::ConstantStrOp::create(rewriter, op.getLoc(), "reflect"); rewriter.replaceOpWithNewOp( op, op.getType(), op.getSelf(), op.getNFft(), op.getHopLength(), op.getWinLength(), op.getWindow(), falseVal, padMode, @@ -4154,8 +4156,8 @@ void AtenCatOp::getCanonicalizationPatterns(RewritePatternSet &patterns, if (filtered.size() == list.getNumOperands()) return failure(); - auto newlist = rewriter.create( - op.getLoc(), list.getType(), filtered); + auto newlist = PrimListConstructOp::create(rewriter, op.getLoc(), + list.getType(), filtered); rewriter.replaceOpWithNewOp(op, op.getType(), newlist, op.getDim()); return success(); @@ -4353,9 +4355,10 @@ void 
AtenMulIntOp::getCanonicalizationPatterns(RewritePatternSet &patterns, return failure(); if ((prevLConstant || prevRConstant) && prevMulIntOp->hasOneUse() == 1) { - auto newConstant = rewriter.create( - op.getLoc(), rewriter.getI64IntegerAttr( - prevLConstant ? prevLhs * firstConstant + auto newConstant = Torch::ConstantIntOp::create( + rewriter, op.getLoc(), + rewriter.getI64IntegerAttr(prevLConstant + ? prevLhs * firstConstant : prevRhs * firstConstant)); rewriter.replaceOpWithNewOp( op, op.getType(), prevMulIntOp.getOperand(prevLConstant ? 1 : 0), @@ -5169,16 +5172,16 @@ void AtenWhereScalarOp::getCanonicalizationPatterns(RewritePatternSet &patterns, SmallVector dims; auto torchIntTy = rewriter.getType(); for (int i = 0, s = condTy.getSizes().size(); i < s; ++i) { - Value iv = rewriter.create( - op.getLoc(), torchIntTy, rewriter.getI64IntegerAttr(i)); - dims.push_back(rewriter.create( - op.getLoc(), torchIntTy, cond, iv)); + Value iv = Torch::ConstantIntOp::create(rewriter, op.getLoc(), torchIntTy, + rewriter.getI64IntegerAttr(i)); + dims.push_back(Torch::AtenSizeIntOp::create(rewriter, op.getLoc(), + torchIntTy, cond, iv)); } - Value dimsList = rewriter.create( - op.getLoc(), Torch::ListType::get(torchIntTy), dims); + Value dimsList = Torch::PrimListConstructOp::create( + rewriter, op.getLoc(), Torch::ListType::get(torchIntTy), dims); - Value none = rewriter.create(op.getLoc()); + Value none = Torch::ConstantNoneOp::create(rewriter, op.getLoc()); rewriter.replaceOpWithNewOp( op, op.getType(), dimsList, self, none, none, none, none); return success(); @@ -5681,8 +5684,9 @@ struct SimplifyMaxPoolWithIndices : public mlir::OpRewritePattern { op, "result1 of MaxPoolWithIndices should be unused"); } - Value result = rewriter.create::type>( - op->getLoc(), op.getResult0().getType(), op.getSelf(), + using ResultOpType = typename MaxPoolWithoutIndices::type; + Value result = ResultOpType::create( + rewriter, op->getLoc(), op.getResult0().getType(), op.getSelf(), 
op.getKernelSize(), op.getStride(), op.getPadding(), op.getDilation(), op.getCeilMode()); diff --git a/lib/Dialect/Torch/Transforms/AdjustCallingConventions.cpp b/lib/Dialect/Torch/Transforms/AdjustCallingConventions.cpp index 1c30df45c11d..25a38c83627c 100644 --- a/lib/Dialect/Torch/Transforms/AdjustCallingConventions.cpp +++ b/lib/Dialect/Torch/Transforms/AdjustCallingConventions.cpp @@ -132,19 +132,20 @@ class AdjustCallingConventionForCall newOperands.push_back(operand.value()); } - func::CallOp newCall = rewriter.create( - call.getLoc(), call.getCallee(), convertedResults, newOperands); + func::CallOp newCall = + func::CallOp::create(rewriter, call.getLoc(), call.getCallee(), + convertedResults, newOperands); int newOpResultIdx = 0; SmallVector newResults; for (auto type : call.getResultTypes()) { if (isa(type)) { newResults.push_back( - rewriter.create(call.getLoc(), type)); + ConstantNoneOp::create(rewriter, call.getLoc(), type)); continue; } if (isa(type)) { - newResults.push_back(rewriter.create( - call.getLoc(), type, newCall.getResults())); + newResults.push_back(PrimTupleConstructOp::create( + rewriter, call.getLoc(), type, newCall.getResults())); continue; } newResults.push_back(newCall.getResult(newOpResultIdx++)); @@ -196,10 +197,10 @@ class AdjustCallingConventionForReturn if (auto tuple = dyn_cast(operand.getType())) { Location loc = op.getLoc(); for (auto en : llvm::enumerate(tuple.getContainedTypes())) { - auto i = rewriter.create( - loc, rewriter.getI64IntegerAttr(en.index())); - newOperands.push_back(rewriter.create( - loc, en.value(), operand, i)); + auto i = ConstantIntOp::create( + rewriter, loc, rewriter.getI64IntegerAttr(en.index())); + newOperands.push_back(PrimTupleIndexOp::create( + rewriter, loc, en.value(), operand, i)); } continue; } diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp index cc36ceeb953b..c9c42b43c463 100644 --- 
a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp +++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp @@ -83,23 +83,23 @@ static Type computeReductionType(PatternRewriter &rewriter, Operation *op, static Value createSumAlongDimension(PatternRewriter &rewriter, Location loc, Operation *op, Value input, Value dim, bool keepDim) { - Value dimList = rewriter.create( - loc, Torch::ListType::get(dim.getType()), dim); - Value keepDimCst = rewriter.create(loc, keepDim); - Value dtype = rewriter.create(loc); + Value dimList = PrimListConstructOp::create( + rewriter, loc, Torch::ListType::get(dim.getType()), dim); + Value keepDimCst = ConstantBoolOp::create(rewriter, loc, keepDim); + Value dtype = ConstantNoneOp::create(rewriter, loc); Type resultType = computeReductionType( rewriter, op, cast(input.getType()), dim, keepDim); if (!resultType) return nullptr; - return rewriter.create(loc, resultType, input, dimList, - keepDimCst, dtype); + return AtenSumDimIntListOp::create(rewriter, loc, resultType, input, dimList, + keepDimCst, dtype); } // Reduction function to calculate max along given `dim`. static Value createMaxAlongDimension(PatternRewriter &rewriter, Location loc, Operation *op, Value input, Value dim, bool keepDim) { - Value keepDimCst = rewriter.create(loc, keepDim); + Value keepDimCst = ConstantBoolOp::create(rewriter, loc, keepDim); BaseTensorType valueType = cast(computeReductionType( rewriter, op, cast(input.getType()), dim, keepDim)); if (!valueType) @@ -109,8 +109,8 @@ static Value createMaxAlongDimension(PatternRewriter &rewriter, Location loc, !valueType.hasSizes() ? 
std::optional>() : llvm::ArrayRef(valueType.getSizes()), IntegerType::get(op->getContext(), 64, IntegerType::Signed))); - return rewriter - .create(loc, valueType, indexType, input, dim, keepDimCst) + return AtenMaxDimOp::create(rewriter, loc, valueType, indexType, input, dim, + keepDimCst) .getValues(); } @@ -118,7 +118,7 @@ static Value createMaxAlongDimension(PatternRewriter &rewriter, Location loc, static Value createMinAlongDimension(PatternRewriter &rewriter, Location loc, Operation *op, Value input, Value dim, bool keepDim) { - Value keepDimCst = rewriter.create(loc, keepDim); + Value keepDimCst = ConstantBoolOp::create(rewriter, loc, keepDim); BaseTensorType valueType = cast(computeReductionType( rewriter, op, cast(input.getType()), dim, keepDim)); if (!valueType) @@ -128,8 +128,8 @@ static Value createMinAlongDimension(PatternRewriter &rewriter, Location loc, !valueType.hasSizes() ? std::optional>() : llvm::ArrayRef(valueType.getSizes()), IntegerType::get(op->getContext(), 64, IntegerType::Signed))); - return rewriter - .create(loc, valueType, indexType, input, dim, keepDimCst) + return AtenMinDimOp::create(rewriter, loc, valueType, indexType, input, dim, + keepDimCst) .getValues(); } @@ -137,9 +137,9 @@ static Value createMinAlongDimension(PatternRewriter &rewriter, Location loc, static Value createTensorSub(PatternRewriter &rewriter, Location loc, Type tensorType, Value lhs, Value rhs) { Value alpha = - rewriter.create(loc, rewriter.getF64FloatAttr(1)); + ConstantFloatOp::create(rewriter, loc, rewriter.getF64FloatAttr(1)); Value sub = - rewriter.create(loc, tensorType, lhs, rhs, alpha); + AtenSubTensorOp::create(rewriter, loc, tensorType, lhs, rhs, alpha); return sub; } @@ -155,11 +155,11 @@ static Value createSoftmaxBackwardCommonKernel(PatternRewriter &rewriter, return nullptr; auto broadcastSizeType = Torch::ListType::get(Torch::IntType::get(op->getContext())); - Value broadcastSize = rewriter.create(loc, broadcastSizeType, z); + Value broadcastSize = 
AtenSizeOp::create(rewriter, loc, broadcastSizeType, z); Value sumBroadcast = - rewriter.create(loc, tensorType, sum, broadcastSize); + AtenBroadcastToOp::create(rewriter, loc, tensorType, sum, broadcastSize); Value temp = - rewriter.create(loc, tensorType, y, sumBroadcast); + AtenMulTensorOp::create(rewriter, loc, tensorType, y, sumBroadcast); Value sub = createTensorSub(rewriter, loc, tensorType, x, temp); return sub; @@ -358,17 +358,17 @@ diagonalizeInputAndRewriteEquation(Location loc, PatternRewriter &rewriter, inputTy = rewriter.getType(newShape, inputTy.getDtype()); - Value zero = rewriter.create( - loc, rewriter.getI64IntegerAttr(0)); + Value zero = Torch::ConstantIntOp::create(rewriter, loc, + rewriter.getI64IntegerAttr(0)); - Value d0Val = rewriter.create( - loc, rewriter.getI64IntegerAttr(d0)); - Value d1Val = rewriter.create( - loc, rewriter.getI64IntegerAttr(d1)); + Value d0Val = Torch::ConstantIntOp::create( + rewriter, loc, rewriter.getI64IntegerAttr(d0)); + Value d1Val = Torch::ConstantIntOp::create( + rewriter, loc, rewriter.getI64IntegerAttr(d1)); - input = rewriter.create(loc, inputTy, /*input=*/input, - /*offset=*/zero, /*dim1=*/d0Val, - /*dim2=*/d1Val); + input = AtenDiagonalOp::create(rewriter, loc, inputTy, /*input=*/input, + /*offset=*/zero, /*dim1=*/d0Val, + /*dim2=*/d1Val); // Frontmost token will have changed: d0--; @@ -404,19 +404,19 @@ static Value collapseDimForMatmul(PatternRewriter &rewriter, Location loc, for (auto i = 0; i < inputRank; ++i) { inputShapeTensor.emplace_back(rewriter.createOrFold( loc, input, - rewriter.create(loc, - rewriter.getI64IntegerAttr(i)))); + Torch::ConstantIntOp::create(rewriter, loc, + rewriter.getI64IntegerAttr(i)))); } SmallVector outShapeTensor; - Value constOne = - rewriter.create(loc, rewriter.getI64IntegerAttr(1)); + Value constOne = Torch::ConstantIntOp::create(rewriter, loc, + rewriter.getI64IntegerAttr(1)); auto dimOffset = 0; auto materializeIntFold = [&](OpFoldResult thing) { if (auto attr = 
dyn_cast(thing)) { - Value result = rewriter.create( - loc, cast(attr)); + Value result = Torch::ConstantIntOp::create( + rewriter, loc, cast(attr)); return result; } return cast(thing); @@ -450,14 +450,15 @@ static Value collapseDimForMatmul(PatternRewriter &rewriter, Location loc, resultShape.push_back(Torch::kUnknownSize); } - auto outShapeValue = rewriter.create( - loc, Torch::ListType::get(Torch::IntType::get(input.getContext())), + auto outShapeValue = Torch::PrimListConstructOp::create( + rewriter, loc, + Torch::ListType::get(Torch::IntType::get(input.getContext())), outShapeTensor); auto outType = inputType.getWithSizesAndDtype(resultShape, inputType.getOptionalDtype()); - return rewriter.create(loc, outType, input, - outShapeValue); + return Torch::AtenReshapeOp::create(rewriter, loc, outType, input, + outShapeValue); } // classify every dim token into different categories. Note that although we @@ -546,8 +547,8 @@ static Value permuteTensorForMatmul(PatternRewriter &rewriter, Location loc, SmallVector permuteVec; auto appendDims = [&](SmallVector dimTokens) { for (auto d : dimTokens) { - permuteVec.push_back(rewriter.create( - loc, rewriter.getI64IntegerAttr(dimTokenMap[d]))); + permuteVec.push_back(Torch::ConstantIntOp::create( + rewriter, loc, rewriter.getI64IntegerAttr(dimTokenMap[d]))); permuteShape.push_back(inputType.getSizes()[dimTokenMap[d]]); } }; @@ -560,13 +561,14 @@ static Value permuteTensorForMatmul(PatternRewriter &rewriter, Location loc, if (isLhs) appendDims(contractingDims); - Value dstDims = rewriter.create( - loc, Torch::ListType::get(Torch::IntType::get(rewriter.getContext())), + Value dstDims = Torch::PrimListConstructOp::create( + rewriter, loc, + Torch::ListType::get(Torch::IntType::get(rewriter.getContext())), permuteVec); auto outType = inputType.getWithSizesAndDtype(permuteShape, inputType.getOptionalDtype()); - return rewriter.create(loc, outType, input, dstDims); + return Torch::AtenPermuteOp::create(rewriter, loc, outType, 
input, dstDims); } static LogicalResult performMatmul(PatternRewriter &rewriter, Location loc, @@ -583,8 +585,8 @@ static LogicalResult performMatmul(PatternRewriter &rewriter, Location loc, auto materializeIntFold = [&](OpFoldResult thing) { if (auto attr = dyn_cast(thing)) { - Value result = rewriter.create( - loc, cast(attr)); + Value result = Torch::ConstantIntOp::create( + rewriter, loc, cast(attr)); return result; } return cast(thing); @@ -595,8 +597,8 @@ static LogicalResult performMatmul(PatternRewriter &rewriter, Location loc, char d = lhsTokens[idx]; OpFoldResult lhsFold = rewriter.createOrFold( loc, lhs, - rewriter.create(loc, - rewriter.getI64IntegerAttr(idx))); + Torch::ConstantIntOp::create(rewriter, loc, + rewriter.getI64IntegerAttr(idx))); lhsDimShapeMap[d] = materializeIntFold(lhsFold); } llvm::SmallDenseMap rhsDimShapeMap; @@ -604,8 +606,8 @@ static LogicalResult performMatmul(PatternRewriter &rewriter, Location loc, char d = rhsTokens[idx]; OpFoldResult rhsFold = rewriter.createOrFold( loc, rhs, - rewriter.create(loc, - rewriter.getI64IntegerAttr(idx))); + Torch::ConstantIntOp::create(rewriter, loc, + rewriter.getI64IntegerAttr(idx))); rhsDimShapeMap[d] = materializeIntFold(rhsFold); } @@ -672,7 +674,7 @@ static LogicalResult performMatmul(PatternRewriter &rewriter, Location loc, // perform matmul auto outType = lhsType.getWithSizesAndDtype(outShape, outputDType); - result = rewriter.create(loc, outType, lhs, rhs); + result = Torch::AtenMatmulOp::create(rewriter, loc, outType, lhs, rhs); // generate ideal result dims. 
generateIdealReusltDimTokens(batchingDims, lhsOtherDims, rhsOtherDims, @@ -697,12 +699,13 @@ static LogicalResult performMatmul(PatternRewriter &rewriter, Location loc, resultShape.push_back(Torch::kUnknownSize); } - auto outResultShape = rewriter.create( - loc, Torch::ListType::get(Torch::IntType::get(lhs.getContext())), + auto outResultShape = Torch::PrimListConstructOp::create( + rewriter, loc, + Torch::ListType::get(Torch::IntType::get(lhs.getContext())), outShapeTensors); - result = rewriter.create( - loc, lhsType.getWithSizesAndDtype(resultShape, outputDType), result, - outResultShape); + result = Torch::AtenReshapeOp::create( + rewriter, loc, lhsType.getWithSizesAndDtype(resultShape, outputDType), + result, outResultShape); return success(); } @@ -729,17 +732,18 @@ static Value performLastReduceAndPermute(PatternRewriter &rewriter, if (sumDims.size() > 0) { SmallVector sumDimsTensor; for (auto d : sumDims) { - sumDimsTensor.emplace_back(rewriter.create( - loc, rewriter.getI64IntegerAttr(d))); + sumDimsTensor.emplace_back(Torch::ConstantIntOp::create( + rewriter, loc, rewriter.getI64IntegerAttr(d))); } - auto sumDimsListValue = rewriter.create( - loc, Torch::ListType::get(Torch::IntType::get(rewriter.getContext())), + auto sumDimsListValue = Torch::PrimListConstructOp::create( + rewriter, loc, + Torch::ListType::get(Torch::IntType::get(rewriter.getContext())), sumDimsTensor); - auto falseValue = rewriter.create( - loc, rewriter.getBoolAttr(false)); - auto noneValue = rewriter.create(loc); - input = rewriter.create( - loc, + auto falseValue = Torch::ConstantBoolOp::create( + rewriter, loc, rewriter.getBoolAttr(false)); + auto noneValue = Torch::ConstantNoneOp::create(rewriter, loc); + input = Torch::AtenSumDimIntListOp::create( + rewriter, loc, inputType.getWithSizesAndDtype(std::nullopt, inputType.getOptionalDtype()), input, sumDimsListValue, falseValue, noneValue); @@ -747,14 +751,15 @@ static Value performLastReduceAndPermute(PatternRewriter &rewriter, 
SmallVector permuteDimsTensor; for (auto d : outTokens) { - permuteDimsTensor.emplace_back(rewriter.create( - loc, rewriter.getI64IntegerAttr(inputDimToIdx[d]))); + permuteDimsTensor.emplace_back(Torch::ConstantIntOp::create( + rewriter, loc, rewriter.getI64IntegerAttr(inputDimToIdx[d]))); } - auto permuteDimsListValue = rewriter.create( - loc, Torch::ListType::get(Torch::IntType::get(input.getContext())), + auto permuteDimsListValue = Torch::PrimListConstructOp::create( + rewriter, loc, + Torch::ListType::get(Torch::IntType::get(input.getContext())), permuteDimsTensor); - auto out = rewriter.create(loc, outType, input, - permuteDimsListValue); + auto out = Torch::AtenPermuteOp::create(rewriter, loc, outType, input, + permuteDimsListValue); return out; } @@ -775,10 +780,10 @@ class DecomposeAtenTriuOp : public OpRewritePattern { } Value cstZero = - rewriter.create(loc, rewriter.getI64IntegerAttr(0)); + ConstantIntOp::create(rewriter, loc, rewriter.getI64IntegerAttr(0)); Value cstOne = - rewriter.create(loc, rewriter.getI64IntegerAttr(1)); - Value none = rewriter.create(loc); + ConstantIntOp::create(rewriter, loc, rewriter.getI64IntegerAttr(1)); + Value none = ConstantNoneOp::create(rewriter, loc); Value rowSize = getTensorDimSize(rewriter, input, -2); Value colSize = getTensorDimSize(rewriter, input, -1); @@ -789,13 +794,13 @@ class DecomposeAtenTriuOp : public OpRewritePattern { auto colArrangeType = getTensorTypeFromShapeValues({colSize}, si64Type); Value rowArange = - rewriter.create(loc, rowArrangeType, rowSize, - /*dtype=*/int64DtypeInt, /*layout=*/none, - /*device=*/none, /*pin_memory=*/none); + AtenArangeOp::create(rewriter, loc, rowArrangeType, rowSize, + /*dtype=*/int64DtypeInt, /*layout=*/none, + /*device=*/none, /*pin_memory=*/none); Value colArange = - rewriter.create(loc, colArrangeType, colSize, - /*dtype=*/int64DtypeInt, /*layout=*/none, - /*device=*/none, /*pin_memory=*/none); + AtenArangeOp::create(rewriter, loc, colArrangeType, colSize, + 
/*dtype=*/int64DtypeInt, /*layout=*/none, + /*device=*/none, /*pin_memory=*/none); auto unsqueezeRowArangeInfo = unsqueezeTensor(rewriter, op, rowArange, cstOne); @@ -810,14 +815,15 @@ class DecomposeAtenTriuOp : public OpRewritePattern { Value unsqueezeRowArange = unsqueezeRowArangeInfo.value(); Value unsqueezeColArange = unsqueezeColArangeInfo.value(); - Value unsqueezeRowArangePlusDiagonal = rewriter.create( - loc, unsqueezeRowArange.getType(), unsqueezeRowArange, op.getDiagonal(), - cstOne); + Value unsqueezeRowArangePlusDiagonal = + AtenAddScalarOp::create(rewriter, loc, unsqueezeRowArange.getType(), + unsqueezeRowArange, op.getDiagonal(), cstOne); auto boolType = rewriter.getI1Type(); auto condType = getTensorTypeFromShapeValues({rowSize, colSize}, boolType); - Value condTensor = rewriter.create( - loc, condType, unsqueezeColArange, unsqueezeRowArangePlusDiagonal); + Value condTensor = + AtenGeTensorOp::create(rewriter, loc, condType, unsqueezeColArange, + unsqueezeRowArangePlusDiagonal); rewriter.replaceOpWithNewOp( op, op.getResult().getType(), condTensor, input, cstZero); @@ -960,14 +966,14 @@ class DecomposeAtenTriuIndicesOp : public OpRewritePattern { return rewriter.notifyMatchFailure(op, "dtype is undefined"); // Constants - Value cstZero = rewriter.create(loc, 0); - Value cstOne = rewriter.create(loc, 1); - Value cstTwo = rewriter.create(loc, 2); - Value cstFalse = rewriter.create(loc, false); - Value cstMinusZeroPointFive = rewriter.create( - loc, rewriter.getF64FloatAttr(-0.5)); - Value cstMinusTwoFloat = rewriter.create( - loc, rewriter.getF64FloatAttr(-2.0)); + Value cstZero = Torch::ConstantIntOp::create(rewriter, loc, 0); + Value cstOne = Torch::ConstantIntOp::create(rewriter, loc, 1); + Value cstTwo = Torch::ConstantIntOp::create(rewriter, loc, 2); + Value cstFalse = ConstantBoolOp::create(rewriter, loc, false); + Value cstMinusZeroPointFive = Torch::ConstantFloatOp::create( + rewriter, loc, rewriter.getF64FloatAttr(-0.5)); + Value 
cstMinusTwoFloat = Torch::ConstantFloatOp::create( + rewriter, loc, rewriter.getF64FloatAttr(-2.0)); // Calculte trapezoidSize, rectangleSize and mFirstRow std::tuple triuSizes = @@ -978,12 +984,12 @@ class DecomposeAtenTriuIndicesOp : public OpRewritePattern { int64_t mFirstRowInt = std::get<2>(triuSizes); // Create const int Values from ints - Value trapezoidSize = rewriter.create( - loc, rewriter.getI64IntegerAttr(trapezoidSizeInt)); - Value rectangleSize = rewriter.create( - loc, rewriter.getI64IntegerAttr(rectangleSizeInt)); - Value mFirstRow = rewriter.create( - loc, rewriter.getI64IntegerAttr(mFirstRowInt)); + Value trapezoidSize = Torch::ConstantIntOp::create( + rewriter, loc, rewriter.getI64IntegerAttr(trapezoidSizeInt)); + Value rectangleSize = Torch::ConstantIntOp::create( + rewriter, loc, rewriter.getI64IntegerAttr(rectangleSizeInt)); + Value mFirstRow = Torch::ConstantIntOp::create( + rewriter, loc, rewriter.getI64IntegerAttr(mFirstRowInt)); // Calculte column offset Value colOffset = (offsetInt > 0) ? 
offset : cstZero; @@ -991,73 +997,71 @@ class DecomposeAtenTriuIndicesOp : public OpRewritePattern { // Calculate indices for top rectangle auto arrangeType = getTensorTypeFromShapeValues({rectangleSize}, *dtypeType); - Value xs2 = - rewriter.create(loc, arrangeType, rectangleSize, - /*dtype=*/dtype, /*layout=*/layout, - /*device=*/device, - /*pin_memory=*/pinMemory); + Value xs2 = AtenArangeOp::create(rewriter, loc, arrangeType, rectangleSize, + /*dtype=*/dtype, /*layout=*/layout, + /*device=*/device, + /*pin_memory=*/pinMemory); // Calculate row_indices2 and column_idices 2 Value rowInds2 = - rewriter.create(loc, xs2.getType(), xs2, col); + AtenFloorDivideScalarOp::create(rewriter, loc, xs2.getType(), xs2, col); Value colInds2 = - rewriter.create(loc, xs2.getType(), xs2, col); + AtenRemainderScalarOp::create(rewriter, loc, xs2.getType(), xs2, col); // Bottom trapezoid auto f64DtypeInt = getDtypeIntValueForType(rewriter, loc, rewriter.getF64Type()); arrangeType = getTensorTypeFromShapeValues({trapezoidSize}, rewriter.getF64Type()); - Value xs1 = - rewriter.create(loc, arrangeType, trapezoidSize, - /*dtype=*/f64DtypeInt, /*layout=*/layout, - /*device=*/device, - /*pin_memory=*/pinMemory); + Value xs1 = AtenArangeOp::create(rewriter, loc, arrangeType, trapezoidSize, + /*dtype=*/f64DtypeInt, /*layout=*/layout, + /*device=*/device, + /*pin_memory=*/pinMemory); // b = -0.5 - m_first_row - Value mFirstRowFloat = rewriter.create( - loc, rewriter.getF64FloatAttr(mFirstRowInt)); - Value b = rewriter.create(loc, cstMinusZeroPointFive, - mFirstRowFloat); + Value mFirstRowFloat = Torch::ConstantFloatOp::create( + rewriter, loc, rewriter.getF64FloatAttr(mFirstRowInt)); + Value b = AtenSubFloatOp::create(rewriter, loc, cstMinusZeroPointFive, + mFirstRowFloat); // Implements this piece of code: row_inds1 = torch.floor(-b - torch.sqrt(b // * b - 2 * xs1)) - Value bSquare = rewriter.create(loc, b, b); + Value bSquare = AtenMulFloatOp::create(rewriter, loc, b, b); - Value 
twoTimesXs1 = rewriter.create(loc, xs1.getType(), - xs1, cstMinusTwoFloat); - Value sqrtInput = rewriter.create( - loc, twoTimesXs1.getType(), twoTimesXs1, bSquare, cstOne); + Value twoTimesXs1 = AtenMulScalarOp::create(rewriter, loc, xs1.getType(), + xs1, cstMinusTwoFloat); + Value sqrtInput = AtenAddScalarOp::create( + rewriter, loc, twoTimesXs1.getType(), twoTimesXs1, bSquare, cstOne); Value sqrt = - rewriter.create(loc, sqrtInput.getType(), sqrtInput); - Value negativeSqrt = rewriter.create(loc, sqrt.getType(), sqrt); + AtenSqrtOp::create(rewriter, loc, sqrtInput.getType(), sqrtInput); + Value negativeSqrt = AtenNegOp::create(rewriter, loc, sqrt.getType(), sqrt); - Value rowInds1 = rewriter.create( - loc, negativeSqrt.getType(), negativeSqrt, b, cstOne); - rowInds1 = rewriter.create(loc, rowInds1.getType(), rowInds1); + Value rowInds1 = AtenSubScalarOp::create( + rewriter, loc, negativeSqrt.getType(), negativeSqrt, b, cstOne); + rowInds1 = AtenFloorOp::create(rewriter, loc, rowInds1.getType(), rowInds1); // Implements this piece of code: col_inds1 = torch.floor(xs1 - ((2 * // m_first_row - 1 - row_inds1) * row_inds1) * 0.5) Value twoTimesMFirstRow = - rewriter.create(loc, cstTwo, mFirstRow); + AtenMulIntOp::create(rewriter, loc, cstTwo, mFirstRow); twoTimesMFirstRow = - rewriter.create(loc, twoTimesMFirstRow, cstOne); + AtenSubIntOp::create(rewriter, loc, twoTimesMFirstRow, cstOne); Value negativeRowInds1 = - rewriter.create(loc, rowInds1.getType(), rowInds1); + AtenNegOp::create(rewriter, loc, rowInds1.getType(), rowInds1); - negativeRowInds1 = rewriter.create( - loc, negativeRowInds1.getType(), negativeRowInds1, twoTimesMFirstRow, - cstOne); - negativeRowInds1 = rewriter.create( - loc, negativeRowInds1.getType(), negativeRowInds1, rowInds1); - negativeRowInds1 = rewriter.create( - loc, negativeRowInds1.getType(), negativeRowInds1, - cstMinusZeroPointFive); + negativeRowInds1 = + AtenAddScalarOp::create(rewriter, loc, negativeRowInds1.getType(), + 
negativeRowInds1, twoTimesMFirstRow, cstOne); + negativeRowInds1 = AtenMulTensorOp::create( + rewriter, loc, negativeRowInds1.getType(), negativeRowInds1, rowInds1); + negativeRowInds1 = + AtenMulScalarOp::create(rewriter, loc, negativeRowInds1.getType(), + negativeRowInds1, cstMinusZeroPointFive); - Value colInds1 = rewriter.create(loc, xs1.getType(), xs1, - negativeRowInds1, cstOne); - colInds1 = rewriter.create(loc, colInds1.getType(), colInds1); + Value colInds1 = AtenAddTensorOp::create(rewriter, loc, xs1.getType(), xs1, + negativeRowInds1, cstOne); + colInds1 = AtenFloorOp::create(rewriter, loc, colInds1.getType(), colInds1); // Convert to dtype Type int64Type = rewriter.getIntegerType(/*width=*/64, /*isSigned=*/true); @@ -1065,31 +1069,31 @@ class DecomposeAtenTriuIndicesOp : public OpRewritePattern { auto rowInds1Type = cast(rowInds1.getType()); ArrayRef sizes = rowInds1Type.getSizes(); Type finalRowType = rowInds1Type.getWithSizesAndDtype(sizes, int64Type); - rowInds1 = rewriter.create( - loc, finalRowType, rowInds1, dtype, - /*non_blocking=*/cstFalse, /*copy=*/cstFalse, - /*memory_format=*/cstOne); + rowInds1 = + AtenToDtypeOp::create(rewriter, loc, finalRowType, rowInds1, dtype, + /*non_blocking=*/cstFalse, /*copy=*/cstFalse, + /*memory_format=*/cstOne); auto colInds1Type = cast(colInds1.getType()); sizes = colInds1Type.getSizes(); Type finalColType = colInds1Type.getWithSizesAndDtype(sizes, int64Type); - colInds1 = rewriter.create( - loc, finalColType, colInds1, dtype, - /*non_blocking=*/cstFalse, /*copy=*/cstFalse, - /*memory_format=*/cstOne); + colInds1 = + AtenToDtypeOp::create(rewriter, loc, finalColType, colInds1, dtype, + /*non_blocking=*/cstFalse, /*copy=*/cstFalse, + /*memory_format=*/cstOne); // Final calculation for row and col indices if (colInt) { - Value rectangleSizeDivCol = - rewriter.create(loc, rectangleSizeInt / colInt); + Value rectangleSizeDivCol = Torch::ConstantIntOp::create( + rewriter, loc, rectangleSizeInt / colInt); - rowInds1 
= rewriter.create( - loc, rowInds1.getType(), rowInds1, rectangleSizeDivCol, cstOne); + rowInds1 = AtenAddScalarOp::create(rewriter, loc, rowInds1.getType(), + rowInds1, rectangleSizeDivCol, cstOne); } - colInds1 = rewriter.create(loc, colInds1.getType(), - colInds1, colOffset, cstOne); + colInds1 = AtenAddScalarOp::create(rewriter, loc, colInds1.getType(), + colInds1, colOffset, cstOne); Type listElemType = cast(rowInds1.getType()) @@ -1097,23 +1101,23 @@ class DecomposeAtenTriuIndicesOp : public OpRewritePattern { /*optionalDtype=*/nullptr); Type listType = Torch::ListType::get(listElemType); - Value sequenceRow = rewriter.create( - loc, listType, SmallVector{rowInds2, rowInds1}); - Value sequenceCol = rewriter.create( - loc, listType, SmallVector{colInds2, colInds1}); + Value sequenceRow = Torch::PrimListConstructOp::create( + rewriter, loc, listType, SmallVector{rowInds2, rowInds1}); + Value sequenceCol = Torch::PrimListConstructOp::create( + rewriter, loc, listType, SmallVector{colInds2, colInds1}); // Concatenate row and col indices Type finalCatType = colInds1Type.getWithSizesAndDtype( {rectangleSizeInt + trapezoidSizeInt}, int64Type); - Value catRow = rewriter.create(loc, finalCatType, sequenceRow, - /*dim=*/cstZero); - Value catCol = rewriter.create(loc, finalCatType, sequenceCol, - /*dim=*/cstZero); + Value catRow = AtenCatOp::create(rewriter, loc, finalCatType, sequenceRow, + /*dim=*/cstZero); + Value catCol = AtenCatOp::create(rewriter, loc, finalCatType, sequenceCol, + /*dim=*/cstZero); // Make return value - Value sequence = rewriter.create( - loc, Torch::ListType::get(context, rowInds1.getType()), + Value sequence = PrimListConstructOp::create( + rewriter, loc, Torch::ListType::get(context, rowInds1.getType()), ValueRange{catRow, catCol}); Type finalStackType = colInds1Type.getWithSizesAndDtype( ArrayRef{2, rectangleSizeInt + trapezoidSizeInt}, int64Type); @@ -1166,14 +1170,14 @@ class DecomposeAtenTrilIndicesOp : public OpRewritePattern { Value 
pinMemory = op.getPinMemory(); // Constants - Value cstZero = rewriter.create(loc, 0); - Value cstOne = rewriter.create(loc, 1); - Value cstTwo = rewriter.create(loc, 2); - Value cstFalse = rewriter.create(loc, false); - Value cstZeroPointFive = rewriter.create( - loc, rewriter.getF64FloatAttr(0.5)); - Value cstTwoFloat = rewriter.create( - loc, rewriter.getF64FloatAttr(2.0)); + Value cstZero = Torch::ConstantIntOp::create(rewriter, loc, 0); + Value cstOne = Torch::ConstantIntOp::create(rewriter, loc, 1); + Value cstTwo = Torch::ConstantIntOp::create(rewriter, loc, 2); + Value cstFalse = ConstantBoolOp::create(rewriter, loc, false); + Value cstZeroPointFive = Torch::ConstantFloatOp::create( + rewriter, loc, rewriter.getF64FloatAttr(0.5)); + Value cstTwoFloat = Torch::ConstantFloatOp::create( + rewriter, loc, rewriter.getF64FloatAttr(2.0)); // Get int value for dtype int64_t dtypeInt; @@ -1195,66 +1199,67 @@ class DecomposeAtenTrilIndicesOp : public OpRewritePattern { int64_t mFirstRowInt = std::get<2>(triuSizes); // Create const int Values from ints - Value trapezoidSize = rewriter.create( - loc, rewriter.getI64IntegerAttr(trapezoidSizeInt)); - Value rectangleSize = rewriter.create( - loc, rewriter.getI64IntegerAttr(rectangleSizeInt)); - Value mFirstRow = rewriter.create( - loc, rewriter.getI64IntegerAttr(mFirstRowInt)); + Value trapezoidSize = Torch::ConstantIntOp::create( + rewriter, loc, rewriter.getI64IntegerAttr(trapezoidSizeInt)); + Value rectangleSize = Torch::ConstantIntOp::create( + rewriter, loc, rewriter.getI64IntegerAttr(rectangleSizeInt)); + Value mFirstRow = Torch::ConstantIntOp::create( + rewriter, loc, rewriter.getI64IntegerAttr(mFirstRowInt)); // Calculte column offset int64_t rowOffsetInt = (-offsetInt > 0) ? 
(-offsetInt) : 0; - Value rowOffset = rewriter.create(loc, rowOffsetInt); + Value rowOffset = Torch::ConstantIntOp::create(rewriter, loc, rowOffsetInt); // First we do the indices for TOP trapezoid auto f64DtypeInt = getDtypeIntValueForType(rewriter, loc, rewriter.getF64Type()); auto arrangeType = getTensorTypeFromShapeValues({trapezoidSize}, rewriter.getF64Type()); - Value xs1 = - rewriter.create(loc, arrangeType, trapezoidSize, - /*dtype=*/f64DtypeInt, /*layout=*/layout, - /*device=*/device, - /*pin_memory=*/pinMemory); + Value xs1 = AtenArangeOp::create(rewriter, loc, arrangeType, trapezoidSize, + /*dtype=*/f64DtypeInt, /*layout=*/layout, + /*device=*/device, + /*pin_memory=*/pinMemory); // b = m_first_row - 0.5 - Value mFirstRowFloat = rewriter.create( - loc, rewriter.getF64FloatAttr(mFirstRowInt)); + Value mFirstRowFloat = Torch::ConstantFloatOp::create( + rewriter, loc, rewriter.getF64FloatAttr(mFirstRowInt)); Value b = - rewriter.create(loc, mFirstRowFloat, cstZeroPointFive); + AtenSubFloatOp::create(rewriter, loc, mFirstRowFloat, cstZeroPointFive); // Implements this piece of code: row_inds1 = torch.floor(-b + torch.sqrt(b // * b + 2 * xs1)) - Value bSquare = rewriter.create(loc, b, b); + Value bSquare = AtenMulFloatOp::create(rewriter, loc, b, b); Value twoTimesXs1 = - rewriter.create(loc, xs1.getType(), xs1, cstTwoFloat); - Value sqrtInput = rewriter.create( - loc, twoTimesXs1.getType(), twoTimesXs1, bSquare, cstOne); + AtenMulScalarOp::create(rewriter, loc, xs1.getType(), xs1, cstTwoFloat); + Value sqrtInput = AtenAddScalarOp::create( + rewriter, loc, twoTimesXs1.getType(), twoTimesXs1, bSquare, cstOne); Value sqrt = - rewriter.create(loc, sqrtInput.getType(), sqrtInput); + AtenSqrtOp::create(rewriter, loc, sqrtInput.getType(), sqrtInput); Value rowInds1 = - rewriter.create(loc, sqrt.getType(), sqrt, b, cstOne); - rowInds1 = rewriter.create(loc, rowInds1.getType(), rowInds1); + AtenSubScalarOp::create(rewriter, loc, sqrt.getType(), sqrt, b, cstOne); + 
rowInds1 = AtenFloorOp::create(rewriter, loc, rowInds1.getType(), rowInds1); // Implements this piece of code: col_inds1 = torch.floor(xs1 - (2 * // m_first_row - 1 + row_inds1) * row_inds1 * 0.5) Value twoTimesMFirstRow = - rewriter.create(loc, cstTwo, mFirstRow); + AtenMulIntOp::create(rewriter, loc, cstTwo, mFirstRow); + twoTimesMFirstRow = + AtenSubIntOp::create(rewriter, loc, twoTimesMFirstRow, cstOne); + twoTimesMFirstRow = AtenAddScalarOp::create( + rewriter, loc, rowInds1.getType(), rowInds1, twoTimesMFirstRow, cstOne); + twoTimesMFirstRow = + AtenMulTensorOp::create(rewriter, loc, twoTimesMFirstRow.getType(), + twoTimesMFirstRow, rowInds1); twoTimesMFirstRow = - rewriter.create(loc, twoTimesMFirstRow, cstOne); - twoTimesMFirstRow = rewriter.create( - loc, rowInds1.getType(), rowInds1, twoTimesMFirstRow, cstOne); - twoTimesMFirstRow = rewriter.create( - loc, twoTimesMFirstRow.getType(), twoTimesMFirstRow, rowInds1); - twoTimesMFirstRow = rewriter.create( - loc, twoTimesMFirstRow.getType(), twoTimesMFirstRow, cstZeroPointFive); - - Value colInds1 = rewriter.create( - loc, xs1.getType(), xs1, twoTimesMFirstRow, cstOne); - colInds1 = rewriter.create(loc, colInds1.getType(), colInds1); + AtenMulScalarOp::create(rewriter, loc, twoTimesMFirstRow.getType(), + twoTimesMFirstRow, cstZeroPointFive); + + Value colInds1 = AtenSubTensorOp::create(rewriter, loc, xs1.getType(), xs1, + twoTimesMFirstRow, cstOne); + colInds1 = AtenFloorOp::create(rewriter, loc, colInds1.getType(), colInds1); // Convert top trapezoid indices to dtype Type int64Type = rewriter.getIntegerType(/*width=*/64, /*isSigned=*/true); @@ -1262,41 +1267,40 @@ class DecomposeAtenTrilIndicesOp : public OpRewritePattern { auto rowInds1Type = cast(rowInds1.getType()); ArrayRef sizes = rowInds1Type.getSizes(); Type finalRowType = rowInds1Type.getWithSizesAndDtype(sizes, int64Type); - rowInds1 = rewriter.create(loc, rowInds1.getType(), - rowInds1, rowOffset, cstOne); - rowInds1 = rewriter.create( - loc, 
finalRowType, rowInds1, dtype, - /*non_blocking=*/cstFalse, /*copy=*/cstFalse, - /*memory_format=*/cstOne); + rowInds1 = AtenAddScalarOp::create(rewriter, loc, rowInds1.getType(), + rowInds1, rowOffset, cstOne); + rowInds1 = + AtenToDtypeOp::create(rewriter, loc, finalRowType, rowInds1, dtype, + /*non_blocking=*/cstFalse, /*copy=*/cstFalse, + /*memory_format=*/cstOne); auto colInds1Type = cast(colInds1.getType()); sizes = colInds1Type.getSizes(); Type finalColType = colInds1Type.getWithSizesAndDtype(sizes, int64Type); - colInds1 = rewriter.create( - loc, finalColType, colInds1, dtype, - /*non_blocking=*/cstFalse, /*copy=*/cstFalse, - /*memory_format=*/cstOne); + colInds1 = + AtenToDtypeOp::create(rewriter, loc, finalColType, colInds1, dtype, + /*non_blocking=*/cstFalse, /*copy=*/cstFalse, + /*memory_format=*/cstOne); // Calculate indices for BOTTOM rectangle arrangeType = getTensorTypeFromShapeValues({rectangleSize}, *dtypeType); - Value xs2 = - rewriter.create(loc, arrangeType, rectangleSize, - /*dtype=*/dtype, /*layout=*/layout, - /*device=*/device, - /*pin_memory=*/pinMemory); + Value xs2 = AtenArangeOp::create(rewriter, loc, arrangeType, rectangleSize, + /*dtype=*/dtype, /*layout=*/layout, + /*device=*/device, + /*pin_memory=*/pinMemory); // Implements this line of code: row_inds2 = xs2 // col + (col - m_first_row // + 1 + row_offset) Value rowInds2 = - rewriter.create(loc, xs2.getType(), xs2, col); + AtenFloorDivideScalarOp::create(rewriter, loc, xs2.getType(), xs2, col); int64_t addInt = colInt - mFirstRowInt + 1 + rowOffsetInt; - Value cstAdd = rewriter.create(loc, addInt); - rowInds2 = rewriter.create(loc, rowInds2.getType(), - rowInds2, cstAdd, cstOne); + Value cstAdd = Torch::ConstantIntOp::create(rewriter, loc, addInt); + rowInds2 = AtenAddScalarOp::create(rewriter, loc, rowInds2.getType(), + rowInds2, cstAdd, cstOne); // Implements this line of code: col_inds2 = xs2 % col Value colInds2 = - rewriter.create(loc, xs2.getType(), xs2, col); + 
AtenRemainderScalarOp::create(rewriter, loc, xs2.getType(), xs2, col); // Prepare tensors for concatenation Type listElemType = @@ -1305,23 +1309,23 @@ class DecomposeAtenTrilIndicesOp : public OpRewritePattern { /*optionalDtype=*/nullptr); Type listType = Torch::ListType::get(listElemType); - Value sequenceRow = rewriter.create( - loc, listType, SmallVector{rowInds1, rowInds2}); - Value sequenceCol = rewriter.create( - loc, listType, SmallVector{colInds1, colInds2}); + Value sequenceRow = Torch::PrimListConstructOp::create( + rewriter, loc, listType, SmallVector{rowInds1, rowInds2}); + Value sequenceCol = Torch::PrimListConstructOp::create( + rewriter, loc, listType, SmallVector{colInds1, colInds2}); // Concatenate row and col indices Type finalCatType = colInds1Type.getWithSizesAndDtype( {rectangleSizeInt + trapezoidSizeInt}, int64Type); - Value catRow = rewriter.create(loc, finalCatType, sequenceRow, - /*dim=*/cstZero); - Value catCol = rewriter.create(loc, finalCatType, sequenceCol, - /*dim=*/cstZero); + Value catRow = AtenCatOp::create(rewriter, loc, finalCatType, sequenceRow, + /*dim=*/cstZero); + Value catCol = AtenCatOp::create(rewriter, loc, finalCatType, sequenceCol, + /*dim=*/cstZero); // Make return value - stack row and col indices - Value sequence = rewriter.create( - loc, Torch::ListType::get(context, rowInds1.getType()), + Value sequence = PrimListConstructOp::create( + rewriter, loc, Torch::ListType::get(context, rowInds1.getType()), ValueRange{catRow, catCol}); Type finalStackType = colInds1Type.getWithSizesAndDtype( ArrayRef{2, rectangleSizeInt + trapezoidSizeInt}, int64Type); @@ -1358,12 +1362,13 @@ class DecomposeAtenDeg2radOp : public OpRewritePattern { } Value pi = - rewriter.create(loc, rewriter.getF64FloatAttr(M_PI)); + ConstantFloatOp::create(rewriter, loc, rewriter.getF64FloatAttr(M_PI)); Value basic = - rewriter.create(loc, rewriter.getF64FloatAttr(180.0)); + ConstantFloatOp::create(rewriter, loc, rewriter.getF64FloatAttr(180.0)); Value 
rad = - rewriter.create(loc, op.getType(), self, basic); - Value result = rewriter.create(loc, op.getType(), rad, pi); + AtenDivScalarOp::create(rewriter, loc, op.getType(), self, basic); + Value result = + AtenMulScalarOp::create(rewriter, loc, op.getType(), rad, pi); rewriter.replaceOp(op, result); @@ -1393,18 +1398,18 @@ class DecomposeAtenFliplrOp : public OpRewritePattern { } Location loc = op.getLoc(); - Value constI = rewriter.create( - loc, rewriter.getI64IntegerAttr(1)); + Value constI = Torch::ConstantIntOp::create(rewriter, loc, + rewriter.getI64IntegerAttr(1)); SmallVector dims; dims.push_back(constI); - Value flipDimList = rewriter.create( - loc, + Value flipDimList = Torch::PrimListConstructOp::create( + rewriter, loc, rewriter.getType(rewriter.getType()), dims); - Value flip = rewriter.create(loc, op.getType(), op.getSelf(), - flipDimList); + Value flip = AtenFlipOp::create(rewriter, loc, op.getType(), op.getSelf(), + flipDimList); rewriter.replaceOp(op, flip); return success(); } @@ -1432,18 +1437,18 @@ class DecomposeAtenFlipudOp : public OpRewritePattern { } Location loc = op.getLoc(); - Value constI = rewriter.create( - loc, rewriter.getI64IntegerAttr(0)); + Value constI = Torch::ConstantIntOp::create(rewriter, loc, + rewriter.getI64IntegerAttr(0)); SmallVector dims; dims.push_back(constI); - Value flipDimList = rewriter.create( - loc, + Value flipDimList = Torch::PrimListConstructOp::create( + rewriter, loc, rewriter.getType(rewriter.getType()), dims); - Value flip = rewriter.create(loc, op.getType(), op.getSelf(), - flipDimList); + Value flip = AtenFlipOp::create(rewriter, loc, op.getType(), op.getSelf(), + flipDimList); rewriter.replaceOp(op, flip); return success(); } @@ -1466,13 +1471,14 @@ class DecomposeAtenSizeOp : public OpRewritePattern { unsigned rank = *maybeRank; SmallVector sizes; for (unsigned i = 0; i < rank; i++) { - Value dim = rewriter.create( - loc, rewriter.getI64IntegerAttr(i)); - sizes.push_back(rewriter.create(loc, self, 
dim)); + Value dim = Torch::ConstantIntOp::create(rewriter, loc, + rewriter.getI64IntegerAttr(i)); + sizes.push_back(AtenSizeIntOp::create(rewriter, loc, self, dim)); } - Value sizeList = rewriter.create( - loc, Torch::ListType::get(Torch::IntType::get(context)), sizes); + Value sizeList = PrimListConstructOp::create( + rewriter, loc, Torch::ListType::get(Torch::IntType::get(context)), + sizes); rewriter.replaceOp(op, sizeList); return success(); } @@ -1498,19 +1504,20 @@ class DecomposeAtenSelectIntOp : public OpRewritePattern { // convert `start` to non-negative: start += int(start < 0) * dimSize Value zero = - rewriter.create(loc, rewriter.getI64IntegerAttr(0)); - Value isNegative = rewriter.create(loc, start, zero); - isNegative = rewriter.create(loc, isNegative); - Value dimSize = rewriter.create(loc, self, dim); - Value indexOffset = rewriter.create(loc, isNegative, dimSize); - start = rewriter.create(loc, start, indexOffset); + ConstantIntOp::create(rewriter, loc, rewriter.getI64IntegerAttr(0)); + Value isNegative = AtenLtIntOp::create(rewriter, loc, start, zero); + isNegative = AtenIntBoolOp::create(rewriter, loc, isNegative); + Value dimSize = AtenSizeIntOp::create(rewriter, loc, self, dim); + Value indexOffset = + AtenMulIntOp::create(rewriter, loc, isNegative, dimSize); + start = AtenAddIntOp::create(rewriter, loc, start, indexOffset); Value one = - rewriter.create(loc, rewriter.getI64IntegerAttr(1)); + ConstantIntOp::create(rewriter, loc, rewriter.getI64IntegerAttr(1)); Value startPlusOne = - rewriter.create(loc, one.getType(), start, one); - Value slice = rewriter.create( - loc, + AtenAddIntOp::create(rewriter, loc, one.getType(), start, one); + Value slice = AtenSliceTensorOp::create( + rewriter, loc, computeReductionType(rewriter, op, cast(self.getType()), dim, /*keepDim=*/true), @@ -1556,17 +1563,17 @@ class DecomposePrimTolistOp : public OpRewritePattern { auto scalarTy = resultTy.getContainedType(); Value zero = - rewriter.create(loc, 
rewriter.getI64IntegerAttr(0)); + ConstantIntOp::create(rewriter, loc, rewriter.getI64IntegerAttr(0)); auto extractTy = rewriter.getType( llvm::SmallVector{1}, selfTy.getOptionalDtype()); llvm::SmallVector results; llvm::SmallVector sizes(selfTy.getSizes()); for (int64_t i = 0; i < length; ++i) { Value iv = - rewriter.create(loc, rewriter.getI64IntegerAttr(i)); - Value extract = rewriter.create( - loc, extractTy, self, /*dim=*/zero, /*index=*/iv); - Value scalar = rewriter.create(loc, scalarTy, extract); + ConstantIntOp::create(rewriter, loc, rewriter.getI64IntegerAttr(i)); + Value extract = AtenSelectIntOp::create(rewriter, loc, extractTy, self, + /*dim=*/zero, /*index=*/iv); + Value scalar = AtenItemOp::create(rewriter, loc, scalarTy, extract); results.push_back(scalar); } @@ -1612,15 +1619,15 @@ class DecomposeAtenSplitWithSizesOp auto intTy = rewriter.getType(); Value one = - rewriter.create(loc, rewriter.getI64IntegerAttr(1)); + ConstantIntOp::create(rewriter, loc, rewriter.getI64IntegerAttr(1)); Value begin = - rewriter.create(loc, rewriter.getI64IntegerAttr(0)); + ConstantIntOp::create(rewriter, loc, rewriter.getI64IntegerAttr(0)); llvm::SmallVector slices; llvm::SmallVector sliceSizes(sliceTy.getSizes()); int64_t defaultLength = !hasDim ? 
Torch::kUnknownSize : sliceSizes[dimInt]; for (auto size : splitSizes) { - Value end = rewriter.create(loc, intTy, begin, size); + Value end = AtenAddIntOp::create(rewriter, loc, intTy, begin, size); int64_t sizeInt; if (hasDim && matchPattern(size, m_TorchConstantInt(&sizeInt))) { @@ -1631,8 +1638,8 @@ class DecomposeAtenSplitWithSizesOp sliceTy = rewriter.getType(sliceSizes, sliceTy.getOptionalDtype()); - Value slice = rewriter.create( - loc, sliceTy, op.getSelf(), + Value slice = AtenSliceTensorOp::create( + rewriter, loc, sliceTy, op.getSelf(), /*dim=*/op.getDim(), /*start=*/begin, /*end=*/end, /*step=*/one); slices.push_back(slice); begin = end; @@ -1657,9 +1664,9 @@ class DecomposeAtenNarrowOp : public OpRewritePattern { Value length = op.getLength(); Value one = - rewriter.create(loc, rewriter.getI64IntegerAttr(1)); + ConstantIntOp::create(rewriter, loc, rewriter.getI64IntegerAttr(1)); Value startPlusLength = - rewriter.create(loc, one.getType(), start, length); + AtenAddIntOp::create(rewriter, loc, one.getType(), start, length); rewriter.replaceOpWithNewOp( op, op.getResult().getType(), op.getSelf(), /*dim=*/dim, @@ -1683,8 +1690,8 @@ class DecomposeAtenNarrowTensorOp auto *context = op.getContext(); // PyTorch makes sure that `start` param is an 0-dim integral tensor. // REF: https://pytorch.org/docs/stable/generated/torch.narrow.html. 
- auto start = rewriter.create( - loc, Torch::IntType::get(context), op.getStart()); + auto start = Torch::AtenScalarImplicitOp::create( + rewriter, loc, Torch::IntType::get(context), op.getStart()); rewriter.replaceOpWithNewOp( op, op.getType(), op.getSelf(), op.getDim(), start, op.getLength()); return success(); @@ -1709,26 +1716,27 @@ class DecomposeAtenGluOp : public OpRewritePattern { } Value zero = - rewriter.create(loc, rewriter.getI64IntegerAttr(0)); - Value dimSize = rewriter.create(loc, self, dim); + ConstantIntOp::create(rewriter, loc, rewriter.getI64IntegerAttr(0)); + Value dimSize = AtenSizeIntOp::create(rewriter, loc, self, dim); Value two = - rewriter.create(loc, rewriter.getI64IntegerAttr(2)); + ConstantIntOp::create(rewriter, loc, rewriter.getI64IntegerAttr(2)); - Value remainder = rewriter.create(loc, dimSize, two); - Value eqOrNot = rewriter.create(loc, remainder, zero); + Value remainder = AtenRemainderIntOp::create(rewriter, loc, dimSize, two); + Value eqOrNot = AtenEqIntOp::create(rewriter, loc, remainder, zero); - rewriter.create( - loc, eqOrNot, + RuntimeAssertOp::create( + rewriter, loc, eqOrNot, rewriter.getStringAttr("AtenGluOp's dim size must be multiple of 2")); - Value splitLength = rewriter.create(loc, dimSize, two); - Value a = rewriter.create(loc, outputTy, self, dim, zero, - splitLength); - Value b = rewriter.create(loc, outputTy, self, dim, - splitLength, splitLength); + Value splitLength = AtenFloordivIntOp::create(rewriter, loc, dimSize, two); + Value a = AtenNarrowOp::create(rewriter, loc, outputTy, self, dim, zero, + splitLength); + Value b = AtenNarrowOp::create(rewriter, loc, outputTy, self, dim, + splitLength, splitLength); // a⊗σ(b) - Value sigmoidB = rewriter.create(loc, outputTy, b); - Value result = rewriter.create(loc, outputTy, a, sigmoidB); + Value sigmoidB = AtenSigmoidOp::create(rewriter, loc, outputTy, b); + Value result = + AtenMulTensorOp::create(rewriter, loc, outputTy, a, sigmoidB); rewriter.replaceOp(op, 
result); return success(); } @@ -1741,8 +1749,8 @@ class DecomposeAtenZeroOp : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(AtenZeroOp op, PatternRewriter &rewriter) const override { - Value zero = rewriter.create(op.getLoc(), - rewriter.getI64IntegerAttr(0)); + Value zero = ConstantIntOp::create(rewriter, op.getLoc(), + rewriter.getI64IntegerAttr(0)); rewriter.replaceOpWithNewOp(op, op.getType(), op.getSelf(), zero); return success(); @@ -1780,7 +1788,7 @@ class DecomposeAtenEyeMOp : public OpRewritePattern { if (!outType.hasDtype()) { return rewriter.notifyMatchFailure(op, "result should have dtype"); } - Value none = rewriter.create(loc); + Value none = ConstantNoneOp::create(rewriter, loc); auto context = op.getContext(); auto int64Dtype = getDtypeIntValueForType( rewriter, loc, @@ -1803,18 +1811,19 @@ class DecomposeAtenEyeMOp : public OpRewritePattern { // prepare two unsqueezed ranges that are equal on and only on the diagonal auto rangeNSize = llvm::SmallVector({n}); Type rangeNType = outType.getWithSizesAndDtype(rangeNSize, si64Type); - Value rangeN = rewriter.create( - loc, rangeNType, op.getN(), /*dtype=*/int64Dtype, /*layout=*/none, - /*device=*/op.getDevice(), /*pin_memory=*/none); + Value rangeN = + AtenArangeOp::create(rewriter, loc, rangeNType, op.getN(), + /*dtype=*/int64Dtype, /*layout=*/none, + /*device=*/op.getDevice(), /*pin_memory=*/none); auto rangeMSize = llvm::SmallVector({m}); Type rangeMType = outType.getWithSizesAndDtype(rangeMSize, si64Type); - Value rangeM = rewriter.create( - loc, rangeMType, op.getM(), /*dtype=*/int64Dtype, /*layout=*/none, - /*device=*/none, /*pin_memory=*/none); + Value rangeM = AtenArangeOp::create(rewriter, loc, rangeMType, op.getM(), + /*dtype=*/int64Dtype, /*layout=*/none, + /*device=*/none, /*pin_memory=*/none); - Value constMinusOne = rewriter.create( - loc, rewriter.getI64IntegerAttr(-1)); + Value constMinusOne = Torch::ConstantIntOp::create( + rewriter, 
loc, rewriter.getI64IntegerAttr(-1)); auto unsqzTensorInfo = unsqueezeTensor(rewriter, op, rangeN, /*dim=*/constMinusOne); if (failed(unsqzTensorInfo)) { @@ -1827,7 +1836,7 @@ class DecomposeAtenEyeMOp : public OpRewritePattern { context, cast(op.getType()).getSizes(), IntegerType::get(context, 1)); Value eqTensor = - rewriter.create(loc, eqType, unsqzRangeN, rangeM); + AtenEqTensorOp::create(rewriter, loc, eqType, unsqzRangeN, rangeM); Value dtype = op.getDtype(); if (isa(dtype.getType())) { @@ -1835,11 +1844,11 @@ class DecomposeAtenEyeMOp : public OpRewritePattern { return success(); } else { auto zero = - rewriter.create(loc, rewriter.getF64FloatAttr(0.0)); + ConstantFloatOp::create(rewriter, loc, rewriter.getF64FloatAttr(0.0)); auto one = - rewriter.create(loc, rewriter.getF64FloatAttr(1.0)); - Value outTensor = - rewriter.create(loc, outType, eqTensor, one, zero); + ConstantFloatOp::create(rewriter, loc, rewriter.getF64FloatAttr(1.0)); + Value outTensor = AtenWhereScalarOp::create(rewriter, loc, outType, + eqTensor, one, zero); rewriter.replaceOp(op, outTensor); return success(); } @@ -1871,10 +1880,11 @@ class DecomposeAtenIsinfOp : public OpRewritePattern { Value self = op.getSelf(); mlir::FloatType f64Type = rewriter.getF64Type(); - Value inf = rewriter.create( - loc, rewriter.getFloatAttr( - f64Type, APFloat::getInf(f64Type.getFloatSemantics()))); - Value abs = rewriter.create(loc, self.getType(), self); + Value inf = ConstantFloatOp::create( + rewriter, loc, + rewriter.getFloatAttr(f64Type, + APFloat::getInf(f64Type.getFloatSemantics()))); + Value abs = AtenAbsOp::create(rewriter, loc, self.getType(), self); rewriter.replaceOpWithNewOp(op, op.getType(), abs, inf); return success(); } @@ -1887,8 +1897,8 @@ class DecomposeAtenIsneginfOp : public OpRewritePattern { LogicalResult matchAndRewrite(AtenIsneginfOp op, PatternRewriter &rewriter) const override { mlir::FloatType f64Type = rewriter.getF64Type(); - Value inf = rewriter.create( - op.getLoc(), + Value 
inf = ConstantFloatOp::create( + rewriter, op.getLoc(), rewriter.getFloatAttr( f64Type, APFloat::getInf(f64Type.getFloatSemantics(), true))); rewriter.replaceOpWithNewOp(op, op.getType(), op.getSelf(), @@ -1904,8 +1914,8 @@ class DecomposeAtenIsposinfOp : public OpRewritePattern { LogicalResult matchAndRewrite(AtenIsposinfOp op, PatternRewriter &rewriter) const override { mlir::FloatType f64Type = rewriter.getF64Type(); - Value inf = rewriter.create( - op.getLoc(), + Value inf = ConstantFloatOp::create( + rewriter, op.getLoc(), rewriter.getFloatAttr(f64Type, APFloat::getInf(f64Type.getFloatSemantics()))); rewriter.replaceOpWithNewOp(op, op.getType(), op.getSelf(), @@ -1960,8 +1970,8 @@ class DecomposeAtenAtleast1dOp : public OpRewritePattern { auto inpType = cast(input.getType()); SmallVector inputShape(inpType.getSizes()); if (inputShape.empty()) { - Value zero = rewriter.create( - loc, rewriter.getI64IntegerAttr(0)); + Value zero = Torch::ConstantIntOp::create(rewriter, loc, + rewriter.getI64IntegerAttr(0)); rewriter.replaceOpWithNewOp(op, opType, input, zero); return success(); } @@ -2009,9 +2019,9 @@ class DecomposeAtenAtleast2dOp : public OpRewritePattern { auto atleast1dResType = rewriter.getType( atleast1dResShape, inputType.getOptionalDtype()); auto atleast1dRes = - rewriter.create(loc, atleast1dResType, input); - Value zero = rewriter.create( - loc, rewriter.getI64IntegerAttr(0)); + AtenAtleast1dOp::create(rewriter, loc, atleast1dResType, input); + Value zero = Torch::ConstantIntOp::create(rewriter, loc, + rewriter.getI64IntegerAttr(0)); rewriter.replaceOpWithNewOp(op, opType, atleast1dRes, zero); return success(); @@ -2167,30 +2177,30 @@ class DecomposeAten_TrilinearOp : public OpRewritePattern { SmallVector sortedExpand1 = expand1; std::sort(sortedExpand1.begin(), sortedExpand1.end()); for (auto expand : sortedExpand1) { - Value expandDim = rewriter.create( - loc, rewriter.getI64IntegerAttr(expand)); + Value expandDim = Torch::ConstantIntOp::create( + 
rewriter, loc, rewriter.getI64IntegerAttr(expand)); input1 = *unsqueezeTensor(rewriter, op, input1, expandDim); } SmallVector sortedExpand2 = expand2; std::sort(sortedExpand2.begin(), sortedExpand2.end()); for (auto expand : sortedExpand2) { - Value expandDim = rewriter.create( - loc, rewriter.getI64IntegerAttr(expand)); + Value expandDim = Torch::ConstantIntOp::create( + rewriter, loc, rewriter.getI64IntegerAttr(expand)); input2 = *unsqueezeTensor(rewriter, op, input2, expandDim); } SmallVector sortedExpand3 = expand3; std::sort(sortedExpand3.begin(), sortedExpand3.end()); for (auto expand : sortedExpand3) { - Value expandDim = rewriter.create( - loc, rewriter.getI64IntegerAttr(expand)); + Value expandDim = Torch::ConstantIntOp::create( + rewriter, loc, rewriter.getI64IntegerAttr(expand)); input3 = *unsqueezeTensor(rewriter, op, input3, expandDim); } // Apply multiplication operation. auto mul1 = - rewriter.create(loc, op.getType(), input1, input2); + AtenMulTensorOp::create(rewriter, loc, op.getType(), input1, input2); auto mul2 = - rewriter.create(loc, op.getType(), mul1, input3); + AtenMulTensorOp::create(rewriter, loc, op.getType(), mul1, input3); // Apply sum operation. 
// Parse sumDim in descending order to avoid any issues with the @@ -2199,8 +2209,8 @@ class DecomposeAten_TrilinearOp : public OpRewritePattern { SmallVector sortedSumDims = sumDim; std::sort(sortedSumDims.rbegin(), sortedSumDims.rend()); for (int64_t dim : sortedSumDims) { - Value dimValue = rewriter.create( - loc, rewriter.getI64IntegerAttr(dim)); + Value dimValue = Torch::ConstantIntOp::create( + rewriter, loc, rewriter.getI64IntegerAttr(dim)); result = createSumAlongDimension(rewriter, loc, op, result, dimValue, false); } @@ -2251,11 +2261,11 @@ class DecomposeAtenTraceOp : public OpRewritePattern { return rewriter.notifyMatchFailure( op, "Expected input tensor to have rank 2."); - Value none = rewriter.create(loc); + Value none = ConstantNoneOp::create(rewriter, loc); Value zero = - rewriter.create(loc, rewriter.getI64IntegerAttr(0)); + ConstantIntOp::create(rewriter, loc, rewriter.getI64IntegerAttr(0)); Value one = - rewriter.create(loc, rewriter.getI64IntegerAttr(1)); + ConstantIntOp::create(rewriter, loc, rewriter.getI64IntegerAttr(1)); BaseTensorType inputType = cast(self.getType()); Value output = op.getResult(); @@ -2268,11 +2278,12 @@ class DecomposeAtenTraceOp : public OpRewritePattern { Type diagonalType = inputType.getWithSizesAndDtype( llvm::ArrayRef(diagonalShape), elementType); - Value diagonal = rewriter.create( - loc, diagonalType, /*input=*/self, /*offset=*/zero, /*dim1=*/zero, - /*dim2=*/one); - Value sum = rewriter.create(loc, outputType, /*self=*/diagonal, - /*dtype=*/none); + Value diagonal = + AtenDiagonalOp::create(rewriter, loc, diagonalType, /*input=*/self, + /*offset=*/zero, /*dim1=*/zero, + /*dim2=*/one); + Value sum = AtenSumOp::create(rewriter, loc, outputType, /*self=*/diagonal, + /*dtype=*/none); rewriter.replaceOp(op, sum); return success(); } @@ -2300,14 +2311,14 @@ static Value getSoftmaxResult(OpTy op, Value self, Type resultType, Value unNormalized = createTensorSub(rewriter, loc, self.getType(), self, xMax); Value 
unNormalizedExp = - rewriter.create(loc, self.getType(), unNormalized); + AtenExpOp::create(rewriter, loc, self.getType(), unNormalized); Value sum = createSumAlongDimension(rewriter, loc, op, unNormalizedExp, dim, /*keepDim=*/true); if (!sum) return nullptr; - Value result = rewriter.create(loc, self.getType(), - unNormalizedExp, sum); + Value result = AtenDivTensorOp::create(rewriter, loc, self.getType(), + unNormalizedExp, sum); if (resultType != accumulatorType) result = convertTensorToDtype(rewriter, loc, result, cast(resultType).getDtype()); @@ -2336,10 +2347,10 @@ class DecomposeAtenSoftmaxIntOp : public OpRewritePattern { // If `dtype` arg is non-none then convert the input to `dtype`. if (!isa(op.getDtype().getType())) { Location loc = op.getLoc(); - Value none = rewriter.create(loc); - Value cstFalse = rewriter.create(loc, false); - self = rewriter.create( - loc, resultTensorType, self, + Value none = ConstantNoneOp::create(rewriter, loc); + Value cstFalse = ConstantBoolOp::create(rewriter, loc, false); + self = AtenToDtypeOp::create( + rewriter, loc, resultTensorType, self, getDtypeIntValueForType(rewriter, loc, resultTensorDtype), /*non_blocking=*/cstFalse, /*copy=*/cstFalse, /*memory_format=*/none); } @@ -2385,10 +2396,10 @@ class DecomposeAten_SoftmaxOp : public OpRewritePattern { // that of output's. 
if (halfToFloat) { Location loc = op.getLoc(); - Value none = rewriter.create(loc); - Value cstFalse = rewriter.create(loc, false); - self = rewriter.create( - loc, resultTensorType, self, + Value none = ConstantNoneOp::create(rewriter, loc); + Value cstFalse = ConstantBoolOp::create(rewriter, loc, false); + self = AtenToDtypeOp::create( + rewriter, loc, resultTensorType, self, getDtypeIntValueForType(rewriter, loc, resultTensorDtype), /*non_blocking=*/cstFalse, /*copy=*/cstFalse, /*memory_format=*/none); } @@ -2431,8 +2442,8 @@ class DecomposeAten_SafeSoftmaxOp return rewriter.notifyMatchFailure(op, "dim int is not valid"); Location loc = op.getLoc(); - Value softmax = rewriter.create( - loc, op.getType(), op.getSelf(), op.getDim(), op.getDtype()); + Value softmax = AtenSoftmaxIntOp::create( + rewriter, loc, op.getType(), op.getSelf(), op.getDim(), op.getDtype()); Type resultTensorDtype = resultTensorType.getDtype(); @@ -2443,16 +2454,16 @@ class DecomposeAten_SafeSoftmaxOp auto boolDtype = rewriter.getI1Type(); auto boolTensorType = resultTensorType.getWithSizesAndDtype(sizes, boolDtype); - Value masked = rewriter.create(loc, boolTensorType, - op.getSelf(), negInfinity); + Value masked = AtenEqScalarOp::create(rewriter, loc, boolTensorType, + op.getSelf(), negInfinity); sizes[dimInt] = 1; auto maskedRowsType = resultTensorType.getWithSizesAndDtype(sizes, boolDtype); - Value cstTrue = - rewriter.create(loc, rewriter.getBoolAttr(true)); - Value maskedRows = rewriter.create( - loc, maskedRowsType, masked, op.getDim(), cstTrue); + Value cstTrue = Torch::ConstantBoolOp::create(rewriter, loc, + rewriter.getBoolAttr(true)); + Value maskedRows = AtenAllDimOp::create(rewriter, loc, maskedRowsType, + masked, op.getDim(), cstTrue); Value cstZero = getConstantWithGivenDtypeAndValue(rewriter, loc, 0.0, resultTensorDtype); rewriter.replaceOpWithNewOp( @@ -2485,7 +2496,7 @@ class DecomposeAten_SoftmaxBackwardDataOp return rewriter.notifyMatchFailure(op, "Only support floating 
type"); Value newGrad = - rewriter.create(loc, tensorType, gradOutput, output); + AtenMulTensorOp::create(rewriter, loc, tensorType, gradOutput, output); Value result = createSoftmaxBackwardCommonKernel( rewriter, loc, op, tensorType, newGrad, output, newGrad, dim); if (!result) @@ -2521,9 +2532,9 @@ class DecomposeAtenTanhBackwardOp return rewriter.notifyMatchFailure(op, "Only support floating type"); Value tanhSquare = - rewriter.create(loc, tensorType, output, output); - Value gradMulTanhSquare = rewriter.create( - loc, tensorType, tanhSquare, gradOutput); + AtenMulTensorOp::create(rewriter, loc, tensorType, output, output); + Value gradMulTanhSquare = AtenMulTensorOp::create(rewriter, loc, tensorType, + tanhSquare, gradOutput); Value newGrad = createTensorSub(rewriter, loc, tensorType, gradOutput, gradMulTanhSquare); @@ -2551,7 +2562,7 @@ class DecomposeAten_LogSoftmaxBackwardDataOp if (!tensorType.hasDtype() || !isa(tensorType.getDtype())) return rewriter.notifyMatchFailure(op, "Only support floating type"); - Value expOut = rewriter.create(loc, tensorType, output); + Value expOut = AtenExpOp::create(rewriter, loc, tensorType, output); Value result = createSoftmaxBackwardCommonKernel( rewriter, loc, op, tensorType, gradOutput, expOut, gradOutput, dim); if (!result) @@ -2615,8 +2626,8 @@ class DecomposeAtenAminAmaxOp : public OpRewritePattern { dimInt = toPositiveDim(dimInt, inputRank); if (!isValidDim(dimInt, inputRank)) return rewriter.notifyMatchFailure(op, "dim is statically invalid"); - Value dim = rewriter.create( - loc, rewriter.getI64IntegerAttr(dimInt)); + Value dim = Torch::ConstantIntOp::create( + rewriter, loc, rewriter.getI64IntegerAttr(dimInt)); // The input to the next invocation of aten.max.dim is the output of the // previous aten.max.dim op. 
static_assert(std::is_same_v || @@ -2671,8 +2682,8 @@ class DecomposeAtenArgMinMaxOp : public OpRewritePattern { // first the input tensor is flattened to 1d tensor and then the reduction // happens on the 0th dimension. if (isa(dim.getType())) { - Value zero = rewriter.create(loc, 0); - Value falseValue = rewriter.create(loc, false); + Value zero = ConstantIntOp::create(rewriter, loc, 0); + Value falseValue = ConstantBoolOp::create(rewriter, loc, false); if (inputType.getSizes().size() > 1) { int64_t flattenSize = Torch::kUnknownSize; if (inputType.areAllSizesKnown()) { @@ -2682,43 +2693,40 @@ class DecomposeAtenArgMinMaxOp : public OpRewritePattern { } auto flattenType = cast(inputType.getWithSizesAndDtype( {flattenSize}, inputType.getOptionalDtype())); - Value end = rewriter.create( - loc, rewriter.getI64IntegerAttr(inputRank - 1)); - input = rewriter.create(loc, flattenType, input, - zero, end); + Value end = ConstantIntOp::create( + rewriter, loc, rewriter.getI64IntegerAttr(inputRank - 1)); + input = AtenFlattenUsingIntsOp::create(rewriter, loc, flattenType, + input, zero, end); } Value resultIndices = - rewriter - .create( - loc, - valueTensorType.getWithSizesAndDtype( - ArrayRef{}, valueTensorType.getOptionalDtype()), - indicesTensorType.getWithSizesAndDtype( - ArrayRef{}, - indicesTensorType.getOptionalDtype()), - input, /*dim=*/zero, /*keepdim=*/falseValue) + DecompOpTy::create( + rewriter, loc, + valueTensorType.getWithSizesAndDtype( + ArrayRef{}, valueTensorType.getOptionalDtype()), + indicesTensorType.getWithSizesAndDtype( + ArrayRef{}, indicesTensorType.getOptionalDtype()), + input, /*dim=*/zero, /*keepdim=*/falseValue) .getIndices(); if (keepDim) { Value one = - rewriter.create(loc, rewriter.getI64IntegerAttr(1)); - Value dimList = rewriter.create( - loc, + ConstantIntOp::create(rewriter, loc, rewriter.getI64IntegerAttr(1)); + Value dimList = PrimListConstructOp::create( + rewriter, loc, 
Torch::ListType::get(Torch::IntType::get(rewriter.getContext())), SmallVector(inputRank, one)); - resultIndices = rewriter.create( - loc, - indicesTensorType.getWithSizesAndDtype( - SmallVector(inputRank, 1), - indicesTensorType.getOptionalDtype()), - resultIndices, dimList); + resultIndices = + AtenReshapeOp::create(rewriter, loc, + indicesTensorType.getWithSizesAndDtype( + SmallVector(inputRank, 1), + indicesTensorType.getOptionalDtype()), + resultIndices, dimList); } rewriter.replaceOp(op, resultIndices); return success(); } else { Value resultIndices = - rewriter - .create(loc, valueTensorType, indicesTensorType, - input, dim, op.getKeepdim()) + DecompOpTy::create(rewriter, loc, valueTensorType, indicesTensorType, + input, dim, op.getKeepdim()) .getIndices(); rewriter.replaceOp(op, resultIndices); return success(); @@ -2740,17 +2748,17 @@ class DecomposeAtenAminmaxOp : public OpRewritePattern { rewriter.getType(rewriter.getType()); Value dimList; if (isa(op.getDim().getType())) { - dimList = rewriter.create(loc, listType, - ArrayRef{}); + dimList = Torch::PrimListConstructOp::create(rewriter, loc, listType, + ArrayRef{}); } else { - dimList = rewriter.create( - loc, listType, ArrayRef{op.getDim()}); + dimList = Torch::PrimListConstructOp::create( + rewriter, loc, listType, ArrayRef{op.getDim()}); } - auto amin = rewriter.create( - loc, op.getMin().getType(), op.getSelf(), dimList, op.getKeepdim()); - auto amax = rewriter.create( - loc, op.getMax().getType(), op.getSelf(), dimList, op.getKeepdim()); + auto amin = AtenAminOp::create(rewriter, loc, op.getMin().getType(), + op.getSelf(), dimList, op.getKeepdim()); + auto amax = AtenAmaxOp::create(rewriter, loc, op.getMax().getType(), + op.getSelf(), dimList, op.getKeepdim()); rewriter.replaceOp(op, {amin, amax}); return success(); } @@ -2811,8 +2819,8 @@ class DecomposeAtenBucketizeTensorOp } // unsqueeze input at the last dim to make it broadcastable with boundaries - Value constMinusOne = rewriter.create( - loc, 
rewriter.getI64IntegerAttr(-1)); + Value constMinusOne = Torch::ConstantIntOp::create( + rewriter, loc, rewriter.getI64IntegerAttr(-1)); auto unsqzTensorInfo = unsqueezeTensor(rewriter, op, input, /*dim=*/constMinusOne); if (failed(unsqzTensorInfo)) { @@ -2828,11 +2836,11 @@ class DecomposeAtenBucketizeTensorOp inputType.getWithSizesAndDtype(compareShape, rewriter.getI1Type()); Value compare; if (!right) { - compare = rewriter.create(loc, compareType, unsqzInput, - boundaries); + compare = AtenLeTensorOp::create(rewriter, loc, compareType, unsqzInput, + boundaries); } else { - compare = rewriter.create(loc, compareType, unsqzInput, - boundaries); + compare = AtenLtTensorOp::create(rewriter, loc, compareType, unsqzInput, + boundaries); } // convert the comparison results to float32 as the argmax op input, @@ -2844,28 +2852,28 @@ class DecomposeAtenBucketizeTensorOp // equal to) the boundary value Type indicesType = inputType.getWithSizesAndDtype( inputShape, rewriter.getIntegerType(64, IntegerType::Signed)); - Value constFalse = rewriter.create(loc, false); - Value indices = rewriter.create(loc, indicesType, compareF32, - /*dim=*/constMinusOne, - /*keepdim=*/constFalse); + Value constFalse = Torch::ConstantBoolOp::create(rewriter, loc, false); + Value indices = AtenArgmaxOp::create(rewriter, loc, indicesType, compareF32, + /*dim=*/constMinusOne, + /*keepdim=*/constFalse); // get the comparison results between each input element and the rightmost // boundary value Type withinUpperBoundType = inputType.getWithSizesAndDtype(inputShape, rewriter.getI1Type()); - Value withinUpperBound = rewriter.create( - loc, withinUpperBoundType, compare, /*dim=*/constMinusOne, + Value withinUpperBound = AtenSelectIntOp::create( + rewriter, loc, withinUpperBoundType, compare, /*dim=*/constMinusOne, /*index=*/constMinusOne); // If the input element is less than (or equal to) the rightmost boundary, // take the max index as result. 
Otherwise, the element is beyond the // rightmost boundary, so take the boundary size. - Value constZero = rewriter.create( - loc, rewriter.getI64IntegerAttr(0)); + Value constZero = Torch::ConstantIntOp::create( + rewriter, loc, rewriter.getI64IntegerAttr(0)); Value upperBound = - rewriter.create(loc, boundaries, /*dim=*/constZero); - Value result = rewriter.create( - loc, indicesType, withinUpperBound, indices, upperBound); + AtenSizeIntOp::create(rewriter, loc, boundaries, /*dim=*/constZero); + Value result = AtenWhereScalarOtherOp::create( + rewriter, loc, indicesType, withinUpperBound, indices, upperBound); if (outInt32) { result = convertTensorToDtype( @@ -2896,7 +2904,7 @@ static Value getLogSoftmaxResult(OpTy op, PatternRewriter &rewriter) { return nullptr; Value shifted = createTensorSub(rewriter, loc, tensorType, self, xMax); - Value shiftedExp = rewriter.create(loc, tensorType, shifted); + Value shiftedExp = AtenExpOp::create(rewriter, loc, tensorType, shifted); Value shiftedSumExp = createSumAlongDimension(rewriter, loc, op, shiftedExp, dim, /*keepDim=*/true); @@ -2904,7 +2912,7 @@ static Value getLogSoftmaxResult(OpTy op, PatternRewriter &rewriter) { return nullptr; Value shiftedLogSumExp = - rewriter.create(loc, shiftedSumExp.getType(), shiftedSumExp); + AtenLogOp::create(rewriter, loc, shiftedSumExp.getType(), shiftedSumExp); Value result = createTensorSub(rewriter, loc, op.getType(), shifted, shiftedLogSumExp); return result; @@ -3003,9 +3011,9 @@ class DecomposeAtenLogCumsumExpOp Value dtypeVal = getDtypeIntValueForType(rewriter, loc, inputType.getDtype()); - Value expInput = rewriter.create(loc, resultType, input); - Value cumsum = rewriter.create(loc, resultType, expInput, - op.getDim(), dtypeVal); + Value expInput = AtenExpOp::create(rewriter, loc, resultType, input); + Value cumsum = AtenCumsumOp::create(rewriter, loc, resultType, expInput, + op.getDim(), dtypeVal); rewriter.replaceOpWithNewOp(op, resultType, cumsum); return success(); } @@ 
-3018,8 +3026,8 @@ class DecomposeAtenLogSigmoidOp : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(AtenLogSigmoidOp op, PatternRewriter &rewriter) const override { - Value sigmoid = - rewriter.create(op.getLoc(), op.getType(), op.getSelf()); + Value sigmoid = AtenSigmoidOp::create(rewriter, op.getLoc(), op.getType(), + op.getSelf()); rewriter.replaceOpWithNewOp(op, op.getType(), sigmoid); return success(); } @@ -3038,11 +3046,11 @@ class DecomposeAtenLogAddExpOp : public OpRewritePattern { auto outTy = op.getType(); Value constantOne = - rewriter.create(loc, rewriter.getI64IntegerAttr(1)); - Value expSelf = rewriter.create(loc, outTy, self); - Value expOther = rewriter.create(loc, outTy, other); - Value addValue = rewriter.create(loc, outTy, expSelf, - expOther, constantOne); + ConstantIntOp::create(rewriter, loc, rewriter.getI64IntegerAttr(1)); + Value expSelf = AtenExpOp::create(rewriter, loc, outTy, self); + Value expOther = AtenExpOp::create(rewriter, loc, outTy, other); + Value addValue = AtenAddTensorOp::create(rewriter, loc, outTy, expSelf, + expOther, constantOne); rewriter.replaceOpWithNewOp(op, outTy, addValue); return success(); } @@ -3061,11 +3069,11 @@ class DecomposeAtenLogAddExp2Op : public OpRewritePattern { auto outTy = op.getType(); Value constantOne = - rewriter.create(loc, rewriter.getI64IntegerAttr(1)); - Value expSelf = rewriter.create(loc, outTy, self); - Value expOther = rewriter.create(loc, outTy, other); - Value addValue = rewriter.create(loc, outTy, expSelf, - expOther, constantOne); + ConstantIntOp::create(rewriter, loc, rewriter.getI64IntegerAttr(1)); + Value expSelf = AtenExp2Op::create(rewriter, loc, outTy, self); + Value expOther = AtenExp2Op::create(rewriter, loc, outTy, other); + Value addValue = AtenAddTensorOp::create(rewriter, loc, outTy, expSelf, + expOther, constantOne); rewriter.replaceOpWithNewOp(op, outTy, addValue); return success(); } @@ -3100,32 +3108,32 @@ class 
DecomposeAtenSoftshrinkOp : public OpRewritePattern { } Value zero = - rewriter.create(loc, rewriter.getF64FloatAttr(0.0)); - Value neglambd = rewriter.create( - loc, rewriter.getF64FloatAttr(-lambd)); - Value poslambd = rewriter.create( - loc, rewriter.getF64FloatAttr(lambd)); + ConstantFloatOp::create(rewriter, loc, rewriter.getF64FloatAttr(0.0)); + Value neglambd = Torch::ConstantFloatOp::create( + rewriter, loc, rewriter.getF64FloatAttr(-lambd)); + Value poslambd = Torch::ConstantFloatOp::create( + rewriter, loc, rewriter.getF64FloatAttr(lambd)); Value constOneFloat = - rewriter.create(loc, rewriter.getF64FloatAttr(1.0)); + ConstantFloatOp::create(rewriter, loc, rewriter.getF64FloatAttr(1.0)); auto boolResType = resTy.getWithSizesAndDtype(resTy.getSizes(), rewriter.getI1Type()); Value posMask = - rewriter.create(loc, boolResType, self, poslambd); + AtenGtScalarOp::create(rewriter, loc, boolResType, self, poslambd); Value negMask = - rewriter.create(loc, boolResType, self, neglambd); + AtenLtScalarOp::create(rewriter, loc, boolResType, self, neglambd); - Value posValue = rewriter.create(loc, resTy, self, - poslambd, constOneFloat); - Value negValue = rewriter.create(loc, resTy, self, - neglambd, constOneFloat); + Value posValue = AtenSubScalarOp::create(rewriter, loc, resTy, self, + poslambd, constOneFloat); + Value negValue = AtenAddScalarOp::create(rewriter, loc, resTy, self, + neglambd, constOneFloat); - Value result = rewriter.create(loc, resTy, posMask, - posValue, zero); - result = - rewriter.create(loc, resTy, negMask, negValue, result); + Value result = AtenWhereScalarOtherOp::create(rewriter, loc, resTy, posMask, + posValue, zero); + result = AtenWhereSelfOp::create(rewriter, loc, resTy, negMask, negValue, + result); rewriter.replaceOp(op, result); return success(); @@ -3161,24 +3169,24 @@ class DecomposeAtenHardshrinkOp : public OpRewritePattern { } Value zero = - rewriter.create(loc, rewriter.getF64FloatAttr(0.0)); - Value neglambd = rewriter.create( 
- loc, rewriter.getF64FloatAttr(-lambd)); - Value poslambd = rewriter.create( - loc, rewriter.getF64FloatAttr(lambd)); + ConstantFloatOp::create(rewriter, loc, rewriter.getF64FloatAttr(0.0)); + Value neglambd = Torch::ConstantFloatOp::create( + rewriter, loc, rewriter.getF64FloatAttr(-lambd)); + Value poslambd = Torch::ConstantFloatOp::create( + rewriter, loc, rewriter.getF64FloatAttr(lambd)); auto boolResType = resTy.getWithSizesAndDtype(resTy.getSizes(), rewriter.getI1Type()); Value posMask = - rewriter.create(loc, boolResType, self, poslambd); + AtenGtScalarOp::create(rewriter, loc, boolResType, self, poslambd); Value negMask = - rewriter.create(loc, boolResType, self, neglambd); + AtenLtScalarOp::create(rewriter, loc, boolResType, self, neglambd); - Value result = rewriter.create(loc, resTy, posMask, - self, zero); + Value result = AtenWhereScalarOtherOp::create(rewriter, loc, resTy, posMask, + self, zero); result = - rewriter.create(loc, resTy, negMask, self, result); + AtenWhereSelfOp::create(rewriter, loc, resTy, negMask, self, result); rewriter.replaceOp(op, result); return success(); @@ -3275,10 +3283,10 @@ class DecomposeAtenRenormOp : public OpRewritePattern { "Unimplemented: dim not constant int"); // Define all constants - Value cstTrue = rewriter.create(loc, true); - Value cstZero = rewriter.create(loc, 0); - Value cstOne = rewriter.create(loc, 1); - Value cstNone = rewriter.create(loc); + Value cstTrue = ConstantBoolOp::create(rewriter, loc, true); + Value cstZero = Torch::ConstantIntOp::create(rewriter, loc, 0); + Value cstOne = Torch::ConstantIntOp::create(rewriter, loc, 1); + Value cstNone = ConstantNoneOp::create(rewriter, loc); // Arragne reduce_dims tensor (vector), [0, 1, ... , dim-1, dim+1, ... 
, // ndim-1] @@ -3287,14 +3295,14 @@ class DecomposeAtenRenormOp : public OpRewritePattern { if (i == (uint64_t)dimInt) continue; - Value constI = rewriter.create( - loc, rewriter.getI64IntegerAttr(i)); + Value constI = Torch::ConstantIntOp::create( + rewriter, loc, rewriter.getI64IntegerAttr(i)); reduceDimsVector.push_back(constI); } - Value reduceDimsList = rewriter.create( - loc, + Value reduceDimsList = Torch::PrimListConstructOp::create( + rewriter, loc, rewriter.getType(rewriter.getType()), reduceDimsVector); @@ -3305,7 +3313,7 @@ class DecomposeAtenRenormOp : public OpRewritePattern { inputSize[i] = 1; inputSizeValue.push_back( - rewriter.create(loc, inputSize[i])); + Torch::ConstantIntOp::create(rewriter, loc, inputSize[i])); } // Prepare arguments for linalg.vector_norm @@ -3321,24 +3329,25 @@ class DecomposeAtenRenormOp : public OpRewritePattern { vectorNormOutType = resType.getWithSizesAndDtype(inputSize, dtype); } - auto norm = rewriter.create( - loc, vectorNormOutType, self, p, reduceDimsList, cstTrue, dtypeValue); + auto norm = + AtenLinalgVectorNormOp::create(rewriter, loc, vectorNormOutType, self, + p, reduceDimsList, cstTrue, dtypeValue); // Define epsiolon constant 10^-7 mlir::FloatType f64Type = rewriter.getF64Type(); - Value epsValue = rewriter.create( - loc, rewriter.getFloatAttr(f64Type, 1e-7)); + Value epsValue = ConstantFloatOp::create( + rewriter, loc, rewriter.getFloatAttr(f64Type, 1e-7)); - Value normPlusEps = rewriter.create( - loc, vectorNormOutType, norm, epsValue, cstOne); + Value normPlusEps = AtenAddScalarOp::create( + rewriter, loc, vectorNormOutType, norm, epsValue, cstOne); - Value maxnormTensorValue = rewriter.create( - loc, normPlusEps.getType(), normPlusEps, maxnorm, cstNone, cstNone, - cstNone, cstNone, cstNone); + Value maxnormTensorValue = AtenFullLikeOp::create( + rewriter, loc, normPlusEps.getType(), normPlusEps, maxnorm, cstNone, + cstNone, cstNone, cstNone, cstNone); // Divide maxnorm and normPlusEps - auto 
divideMaxnormAndNorm = rewriter.create( - loc, vectorNormOutType, maxnormTensorValue, normPlusEps); + auto divideMaxnormAndNorm = AtenDivTensorOp::create( + rewriter, loc, vectorNormOutType, maxnormTensorValue, normPlusEps); // Next few lines corespond to this pythorch code: norm_factor = // torch.where(norm > maxnorm, maxnorm / (norm + eps), 1.0) @@ -3347,24 +3356,25 @@ class DecomposeAtenRenormOp : public OpRewritePattern { rewriter.getI1Type()); Value greaterThanMaxnorm = - rewriter.create(loc, boolTensorType, norm, maxnorm); + AtenGtScalarOp::create(rewriter, loc, boolTensorType, norm, maxnorm); - Value cstOnetensor = rewriter.create( - loc, normPlusEps.getType(), normPlusEps, cstOne, cstNone, cstNone, - cstNone, cstNone, cstNone); + Value cstOnetensor = AtenFullLikeOp::create( + rewriter, loc, normPlusEps.getType(), normPlusEps, cstOne, cstNone, + cstNone, cstNone, cstNone, cstNone); - auto normFactor = rewriter.create( - loc, vectorNormOutType, greaterThanMaxnorm, divideMaxnormAndNorm, - cstOnetensor); + auto normFactor = AtenWhereSelfOp::create( + rewriter, loc, vectorNormOutType, greaterThanMaxnorm, + divideMaxnormAndNorm, cstOnetensor); // Converte norm_factor to input dtype - Value normFactorFinal = rewriter.create( - loc, resType.getWithSizesAndDtype(inputSize, resType.getDtype()), - normFactor, getDtypeIntValueForType(rewriter, loc, resType.getDtype())); + Value normFactorFinal = PrimsConvertElementTypeOp::create( + rewriter, loc, + resType.getWithSizesAndDtype(inputSize, resType.getDtype()), normFactor, + getDtypeIntValueForType(rewriter, loc, resType.getDtype())); // Multiply input tensor with norm factor - auto output = rewriter.create(loc, self.getType(), self, - normFactorFinal); + auto output = AtenMulTensorOp::create(rewriter, loc, self.getType(), self, + normFactorFinal); rewriter.replaceOpWithNewOp(op, self.getType(), output, /*memory_format*/ cstZero); @@ -3420,61 +3430,62 @@ class DecomposeAtenLinalgCrossOp : public OpRewritePattern { Type 
broadcastType = ValueTensorType::get( op.getContext(), llvm::ArrayRef(broadcastShape), dtype); - Value indexBroadcastShapeTorchList = rewriter.create( - loc, Torch::ListType::get(Torch::IntType::get(op.getContext())), + Value indexBroadcastShapeTorchList = PrimListConstructOp::create( + rewriter, loc, + Torch::ListType::get(Torch::IntType::get(op.getContext())), broadcastShapeValue); // broadcast tensors to common shape - auto a = rewriter.create(loc, broadcastType, self, - indexBroadcastShapeTorchList); - auto b = rewriter.create(loc, broadcastType, other, - indexBroadcastShapeTorchList); + auto a = AtenBroadcastToOp::create(rewriter, loc, broadcastType, self, + indexBroadcastShapeTorchList); + auto b = AtenBroadcastToOp::create(rewriter, loc, broadcastType, other, + indexBroadcastShapeTorchList); // create constants - Value constOne = rewriter.create( - loc, rewriter.getI64IntegerAttr(1)); - Value constTwo = rewriter.create( - loc, rewriter.getI64IntegerAttr(2)); - Value constThree = rewriter.create( - loc, rewriter.getI64IntegerAttr(3)); - Value none = rewriter.create(loc); + Value constOne = Torch::ConstantIntOp::create( + rewriter, loc, rewriter.getI64IntegerAttr(1)); + Value constTwo = Torch::ConstantIntOp::create( + rewriter, loc, rewriter.getI64IntegerAttr(2)); + Value constThree = Torch::ConstantIntOp::create( + rewriter, loc, rewriter.getI64IntegerAttr(3)); + Value none = ConstantNoneOp::create(rewriter, loc); // idx = torch.arange(3) auto outType = dyn_cast(opType); auto arangeType = outType.getWithSizesAndDtype( llvm::ArrayRef(3), IntegerType::get(op.getContext(), 64, IntegerType::Signed)); - auto idx = rewriter.create( - loc, arangeType, constThree, /*dtype=*/none, /*layout=*/none, - /*device=*/none, /*pin_memory=*/none); + auto idx = AtenArangeOp::create(rewriter, loc, arangeType, constThree, + /*dtype=*/none, /*layout=*/none, + /*device=*/none, /*pin_memory=*/none); // (idx + 1) and (idx + 2) - auto idxPlusOne = rewriter.create(loc, arangeType, idx, 
- constOne, constOne); - auto idxPlusTwo = rewriter.create(loc, arangeType, idx, - constTwo, constOne); + auto idxPlusOne = AtenAddScalarOp::create(rewriter, loc, arangeType, idx, + constOne, constOne); + auto idxPlusTwo = AtenAddScalarOp::create(rewriter, loc, arangeType, idx, + constTwo, constOne); // (idx + 1) % 3 and (idx + 2) % 3 - auto idxPlusOneRemainderThree = rewriter.create( - loc, arangeType, idxPlusOne, constThree); - auto idxPlusTwoRemainderThree = rewriter.create( - loc, arangeType, idxPlusTwo, constThree); + auto idxPlusOneRemainderThree = AtenRemainderScalarOp::create( + rewriter, loc, arangeType, idxPlusOne, constThree); + auto idxPlusTwoRemainderThree = AtenRemainderScalarOp::create( + rewriter, loc, arangeType, idxPlusTwo, constThree); // a.index_select(dim, (idx + 1) % 3) * b.index_select(dim, (idx + 2) % 3) - auto idxSelectAPlusOne = rewriter.create( - loc, opType, a, dim, idxPlusOneRemainderThree); - auto idxSelectBPlusTwo = rewriter.create( - loc, opType, b, dim, idxPlusTwoRemainderThree); - auto firstMul = rewriter.create( - loc, opType, idxSelectAPlusOne, idxSelectBPlusTwo); + auto idxSelectAPlusOne = AtenIndexSelectOp::create( + rewriter, loc, opType, a, dim, idxPlusOneRemainderThree); + auto idxSelectBPlusTwo = AtenIndexSelectOp::create( + rewriter, loc, opType, b, dim, idxPlusTwoRemainderThree); + auto firstMul = AtenMulTensorOp::create( + rewriter, loc, opType, idxSelectAPlusOne, idxSelectBPlusTwo); // a.index_select(dim, (idx + 2) % 3) * b.index_select(dim, (idx + 1) % 3) - auto idxSelectAPlusTwo = rewriter.create( - loc, opType, a, dim, idxPlusTwoRemainderThree); - auto idxSelectBPlusOne = rewriter.create( - loc, opType, b, dim, idxPlusOneRemainderThree); - auto secondMul = rewriter.create( - loc, opType, idxSelectAPlusTwo, idxSelectBPlusOne); + auto idxSelectAPlusTwo = AtenIndexSelectOp::create( + rewriter, loc, opType, a, dim, idxPlusTwoRemainderThree); + auto idxSelectBPlusOne = AtenIndexSelectOp::create( + rewriter, loc, opType, 
b, dim, idxPlusOneRemainderThree); + auto secondMul = AtenMulTensorOp::create( + rewriter, loc, opType, idxSelectAPlusTwo, idxSelectBPlusOne); // subtract the results of the two multiplications from above rewriter.replaceOpWithNewOp(op, opType, firstMul, @@ -3498,14 +3509,14 @@ class DecomposeAtenLinalgSlogdetOp SmallVector results = op.getResults(); Location loc = op.getLoc(); Value input = op.getA(); - Value determinant = rewriter.create( - loc, results[0].getType(), input); + Value determinant = Torch::AtenLinalgDetOp::create( + rewriter, loc, results[0].getType(), input); Value sign = - rewriter.create(loc, determinant.getType(), determinant); + AtenSgnOp::create(rewriter, loc, determinant.getType(), determinant); Value abs_det = - rewriter.create(loc, determinant.getType(), determinant); + AtenAbsOp::create(rewriter, loc, determinant.getType(), determinant); Value ln_abs_det = - rewriter.create(loc, abs_det.getType(), abs_det); + AtenLogOp::create(rewriter, loc, abs_det.getType(), abs_det); rewriter.replaceAllUsesWith(results[0], sign); rewriter.replaceAllUsesWith(results[1], ln_abs_det); rewriter.eraseOp(op); @@ -3527,8 +3538,8 @@ class DecomposeAten_LinalgDetOp : public OpRewritePattern { op, "unsupported: _linalg_det results: LU and pivot"); Location loc = op.getLoc(); Value input = op.getA(); - Value determinant = rewriter.create( - loc, results[0].getType(), input); + Value determinant = Torch::AtenLinalgDetOp::create( + rewriter, loc, results[0].getType(), input); rewriter.replaceAllUsesWith(results[0], determinant); rewriter.eraseOp(op); return success(); @@ -3604,7 +3615,7 @@ class DecomposeAtenPixelShuffleOp dimensionConstants.reserve(inRank + 2); for (unsigned i = 0; i < inRank + 2; ++i) { dimensionConstants.push_back( - rewriter.create(loc, rewriter.getI64IntegerAttr(i))); + ConstantIntOp::create(rewriter, loc, rewriter.getI64IntegerAttr(i))); } SmallVector leadingDims; @@ -3637,33 +3648,34 @@ class DecomposeAtenPixelShuffleOp 
permutation.push_back(dimensionConstants[nLeadingDims + d]); } - Value permuteDimsOrder = rewriter.create( - loc, Torch::ListType::get(Torch::IntType::get(op->getContext())), + Value permuteDimsOrder = PrimListConstructOp::create( + rewriter, loc, + Torch::ListType::get(Torch::IntType::get(op->getContext())), permutation); // Split input channel inC -> (inC, factorSquared) auto partiallyExpanded = - rewriter - .create( - loc, - getTensorTypeFromShapeValues(partiallyExpandedShape, - inOptionalDType), - inValue, dimensionConstants[nLeadingDims], outC) + PrimsSplitDimOp::create(rewriter, loc, + getTensorTypeFromShapeValues( + partiallyExpandedShape, inOptionalDType), + inValue, dimensionConstants[nLeadingDims], outC) .getResult(); // Split new dimension factorSquared -> (factor, factor) - auto fullyExpanded = rewriter.create( - loc, getTensorTypeFromShapeValues(prePermuteShape, inOptionalDType), + auto fullyExpanded = PrimsSplitDimOp::create( + rewriter, loc, + getTensorTypeFromShapeValues(prePermuteShape, inOptionalDType), partiallyExpanded, dimensionConstants[nLeadingDims + 1], factor); // Perform the permutation - auto permuted = rewriter.create( - loc, getTensorTypeFromShapeValues(postPermuteShape, inOptionalDType), + auto permuted = AtenPermuteOp::create( + rewriter, loc, + getTensorTypeFromShapeValues(postPermuteShape, inOptionalDType), fullyExpanded, permuteDimsOrder); // Collapse final 2 dimension - auto partiallyCollapsed = rewriter.create( - loc, + auto partiallyCollapsed = PrimsCollapseOp::create( + rewriter, loc, getTensorTypeFromShapeValues(partiallyCollapsedShape, inOptionalDType), permuted, dimensionConstants[nLeadingDims + 3], dimensionConstants[nLeadingDims + 4]); @@ -3749,7 +3761,7 @@ class DecomposeAtenPixelUnshuffleOp dimensionConstants.reserve(inRank + 2); for (unsigned i = 0; i < inRank + 2; ++i) { dimensionConstants.push_back( - rewriter.create(loc, rewriter.getI64IntegerAttr(i))); + ConstantIntOp::create(rewriter, loc, 
rewriter.getI64IntegerAttr(i))); } SmallVector leadingDims; @@ -3779,8 +3791,9 @@ class DecomposeAtenPixelUnshuffleOp permutation.push_back(dimensionConstants[nLeadingDims + d]); } - Value permuteDimsOrder = rewriter.create( - loc, Torch::ListType::get(Torch::IntType::get(op->getContext())), + Value permuteDimsOrder = PrimListConstructOp::create( + rewriter, loc, + Torch::ListType::get(Torch::IntType::get(op->getContext())), permutation); SmallVector heightSplitShape = leadingDims; @@ -3788,21 +3801,22 @@ class DecomposeAtenPixelUnshuffleOp // Split input channel inH -> (outH, factor) auto partiallyExpanded = - rewriter - .create( - loc, - getTensorTypeFromShapeValues(heightSplitShape, inOptionalDType), - inValue, dimensionConstants[nLeadingDims + 1], outH) + PrimsSplitDimOp::create( + rewriter, loc, + getTensorTypeFromShapeValues(heightSplitShape, inOptionalDType), + inValue, dimensionConstants[nLeadingDims + 1], outH) .getResult(); // Split new dimension inW -> (outW, factor) - auto fullyExpanded = rewriter.create( - loc, getTensorTypeFromShapeValues(prePermuteShape, inOptionalDType), + auto fullyExpanded = PrimsSplitDimOp::create( + rewriter, loc, + getTensorTypeFromShapeValues(prePermuteShape, inOptionalDType), partiallyExpanded, dimensionConstants[nLeadingDims + 3], outW); // Perform the permutation - auto permuted = rewriter.create( - loc, getTensorTypeFromShapeValues(postPermuteShape, inOptionalDType), + auto permuted = AtenPermuteOp::create( + rewriter, loc, + getTensorTypeFromShapeValues(postPermuteShape, inOptionalDType), fullyExpanded, permuteDimsOrder); // Collapse final 2 dimensions back to original rank @@ -3892,7 +3906,7 @@ class DecomposeAtenChannelShuffleOp dimensionConstants.reserve(inRank + 1); for (unsigned i = 0; i < inRank + 1; ++i) { dimensionConstants.push_back( - rewriter.create(loc, rewriter.getI64IntegerAttr(i))); + ConstantIntOp::create(rewriter, loc, rewriter.getI64IntegerAttr(i))); } Value batchDimSize = rewriter.createOrFold( @@ 
-3916,8 +3930,9 @@ class DecomposeAtenChannelShuffleOp permutation.push_back(dimensionConstants[i]); } - Value permuteDimsOrder = rewriter.create( - loc, Torch::ListType::get(Torch::IntType::get(op->getContext())), + Value permuteDimsOrder = PrimListConstructOp::create( + rewriter, loc, + Torch::ListType::get(Torch::IntType::get(op->getContext())), permutation); const auto inOptionalDType = inType.getOptionalDtype(); @@ -3927,15 +3942,16 @@ class DecomposeAtenChannelShuffleOp // Split input channel inC -> (groups, inC/groups) auto expandedTensor = - rewriter - .create( - loc, getTensorTypeFromShapeValues(splitShape, inOptionalDType), - inValue, dimC, tempC) + PrimsSplitDimOp::create( + rewriter, loc, + getTensorTypeFromShapeValues(splitShape, inOptionalDType), inValue, + dimC, tempC) .getResult(); // Perform the permutation - auto permuted = rewriter.create( - loc, getTensorTypeFromShapeValues(permuteShape, inOptionalDType), + auto permuted = AtenPermuteOp::create( + rewriter, loc, + getTensorTypeFromShapeValues(permuteShape, inOptionalDType), expandedTensor, permuteDimsOrder); // Collapse (C, groups) back into a single channel dimension @@ -3952,12 +3968,12 @@ static Value getRelu6Results(PatternRewriter &rewriter, Location loc, Value input) { BaseTensorType inputType = cast(input.getType()); - Value relu = rewriter.create(loc, inputType, input); - Value cst6 = - rewriter.create(loc, rewriter.getI64IntegerAttr(6)); + Value relu = AtenReluOp::create(rewriter, loc, inputType, input); + Value cst6 = Torch::ConstantIntOp::create(rewriter, loc, + rewriter.getI64IntegerAttr(6)); Value sixTensor = createRank0Tensor(rewriter, loc, inputType, cst6); Value relu6Out = - rewriter.create(loc, inputType, relu, sixTensor); + AtenMinimumOp::create(rewriter, loc, inputType, relu, sixTensor); return relu6Out; } @@ -3990,19 +4006,19 @@ class DecomposeAtenHardswishOp : public OpRewritePattern { Value input = op.getSelf(); Type inputType = input.getType(); - Value constantOne = 
rewriter.create( - loc, rewriter.getI64IntegerAttr(1)); - Value constantThree = rewriter.create( - loc, rewriter.getI64IntegerAttr(3)); - Value constantSix = rewriter.create( - loc, rewriter.getI64IntegerAttr(6)); - Value inputPlusThree = rewriter.create( - loc, inputType, input, constantThree, /*alpha=*/constantOne); + Value constantOne = Torch::ConstantIntOp::create( + rewriter, loc, rewriter.getI64IntegerAttr(1)); + Value constantThree = Torch::ConstantIntOp::create( + rewriter, loc, rewriter.getI64IntegerAttr(3)); + Value constantSix = Torch::ConstantIntOp::create( + rewriter, loc, rewriter.getI64IntegerAttr(6)); + Value inputPlusThree = AtenAddScalarOp::create( + rewriter, loc, inputType, input, constantThree, /*alpha=*/constantOne); Value relu6 = getRelu6Results(rewriter, loc, inputPlusThree); Value divTensor = - rewriter.create(loc, inputType, relu6, constantSix); + AtenDivScalarOp::create(rewriter, loc, inputType, relu6, constantSix); Value mulTensor = - rewriter.create(loc, inputType, divTensor, input); + AtenMulTensorOp::create(rewriter, loc, inputType, divTensor, input); rewriter.replaceOp(op, mulTensor); return success(); @@ -4026,18 +4042,19 @@ class DecomposeAtenLeakyReluOp : public OpRewritePattern { } Value constantZero = - rewriter.create(loc, rewriter.getI64IntegerAttr(0)); + ConstantIntOp::create(rewriter, loc, rewriter.getI64IntegerAttr(0)); Value constantOne = - rewriter.create(loc, rewriter.getF64FloatAttr(1.0)); + ConstantFloatOp::create(rewriter, loc, rewriter.getF64FloatAttr(1.0)); Value zeroTensor = createRank0Tensor(rewriter, loc, resType, constantZero); Value positiveOutput = - rewriter.create(loc, resType, zeroTensor, input); + AtenMaximumOp::create(rewriter, loc, resType, zeroTensor, input); Value negativeOutput = - rewriter.create(loc, resType, zeroTensor, input); - Value scaledNegativeOutput = rewriter.create( - loc, resType, negativeOutput, negativeSlope); - Value leakyReluOutput = rewriter.create( - loc, resType, positiveOutput, 
scaledNegativeOutput, constantOne); + AtenMinimumOp::create(rewriter, loc, resType, zeroTensor, input); + Value scaledNegativeOutput = AtenMulScalarOp::create( + rewriter, loc, resType, negativeOutput, negativeSlope); + Value leakyReluOutput = + AtenAddTensorOp::create(rewriter, loc, resType, positiveOutput, + scaledNegativeOutput, constantOne); rewriter.replaceOp(op, leakyReluOutput); return success(); @@ -4070,18 +4087,19 @@ class DecomposeAtenLeakyReluBackwardOp op, "unimplemented: self_is_result should be false"); Value constantZero = - rewriter.create(loc, rewriter.getI64IntegerAttr(0)); + ConstantIntOp::create(rewriter, loc, rewriter.getI64IntegerAttr(0)); Value constantOne = - rewriter.create(loc, rewriter.getF64FloatAttr(1.0)); + ConstantFloatOp::create(rewriter, loc, rewriter.getF64FloatAttr(1.0)); Value zeroTensor = createRank0Tensor(rewriter, loc, resType, constantZero); Value positiveOutput = - rewriter.create(loc, resType, zeroTensor, gradOutput); + AtenMaximumOp::create(rewriter, loc, resType, zeroTensor, gradOutput); Value negativeOutput = - rewriter.create(loc, resType, zeroTensor, input); - Value scaledNegativeOutput = rewriter.create( - loc, resType, negativeOutput, negativeSlope); - Value leakyReluBackwardOutput = rewriter.create( - loc, resType, positiveOutput, scaledNegativeOutput, constantOne); + AtenMinimumOp::create(rewriter, loc, resType, zeroTensor, input); + Value scaledNegativeOutput = AtenMulScalarOp::create( + rewriter, loc, resType, negativeOutput, negativeSlope); + Value leakyReluBackwardOutput = + AtenAddTensorOp::create(rewriter, loc, resType, positiveOutput, + scaledNegativeOutput, constantOne); rewriter.replaceOp(op, leakyReluBackwardOutput); return success(); @@ -4127,12 +4145,12 @@ class DecomposeAtenRreluWithNoiseBackwardOp if (training && (upper - lower > 0.000001)) { Value rreluWithNoiseBackwardOutput = - rewriter.create(loc, resType, gradOutput, noise); + AtenMulTensorOp::create(rewriter, loc, resType, gradOutput, noise); 
rewriter.replaceOp(op, rreluWithNoiseBackwardOutput); } else { double negative_slope = (upper + lower) / 2; - Value cstNegativeSlope = rewriter.create( - loc, rewriter.getF64FloatAttr(negative_slope)); + Value cstNegativeSlope = ConstantFloatOp::create( + rewriter, loc, rewriter.getF64FloatAttr(negative_slope)); rewriter.replaceOpWithNewOp( op, resType, gradOutput, self, cstNegativeSlope, op.getSelfIsResult()); @@ -4155,13 +4173,13 @@ class DecomposeAtenPreluOp : public OpRewritePattern { auto boolTensorType = rewriter.getType( resType.getOptionalSizes(), rewriter.getI1Type()); Value zero = - rewriter.create(loc, rewriter.getF64FloatAttr(0.0)); + ConstantFloatOp::create(rewriter, loc, rewriter.getF64FloatAttr(0.0)); Value inputMulWeight = - rewriter.create(loc, resType, input, weight); + AtenMulTensorOp::create(rewriter, loc, resType, input, weight); Value lessThanZero = - rewriter.create(loc, boolTensorType, input, zero); - Value preluOutput = rewriter.create( - loc, resType, lessThanZero, inputMulWeight, input); + AtenLtScalarOp::create(rewriter, loc, boolTensorType, input, zero); + Value preluOutput = AtenWhereSelfOp::create( + rewriter, loc, resType, lessThanZero, inputMulWeight, input); rewriter.replaceOp(op, preluOutput); return success(); @@ -4194,43 +4212,44 @@ class DecomposeAtenRreluOp : public OpRewritePattern { } Value constantZeroFloat = - rewriter.create(loc, rewriter.getF64FloatAttr(0.0)); + ConstantFloatOp::create(rewriter, loc, rewriter.getF64FloatAttr(0.0)); Value constantOneFloat = - rewriter.create(loc, rewriter.getF64FloatAttr(1.0)); + ConstantFloatOp::create(rewriter, loc, rewriter.getF64FloatAttr(1.0)); Value constantTwoFloat = - rewriter.create(loc, rewriter.getF64FloatAttr(2.0)); + ConstantFloatOp::create(rewriter, loc, rewriter.getF64FloatAttr(2.0)); Value alpha; if (training) { // Create a uniform random op with low and high set to `lower` and // `upper`, respectively. 
- Value none = rewriter.create(loc); - alpha = rewriter.create(loc, resType, self, - /*from=*/lower, /*to=*/upper, - /*generator=*/none); + Value none = ConstantNoneOp::create(rewriter, loc); + alpha = AtenUniformOp::create(rewriter, loc, resType, self, + /*from=*/lower, /*to=*/upper, + /*generator=*/none); } else { - Value half = rewriter.create(loc, constantTwoFloat.getType(), - lower, upper); - alpha = rewriter.create(loc, constantTwoFloat.getType(), half, - constantTwoFloat); + Value half = AtenAddOp::create(rewriter, loc, constantTwoFloat.getType(), + lower, upper); + alpha = AtenDivOp::create(rewriter, loc, constantTwoFloat.getType(), half, + constantTwoFloat); } Value zeroTensor = createRank0Tensor(rewriter, loc, resType, constantZeroFloat); Value positiveOutput = - rewriter.create(loc, resType, zeroTensor, self); + AtenMaximumOp::create(rewriter, loc, resType, zeroTensor, self); Value scaledSelf; if (training) { - scaledSelf = rewriter.create(loc, resType, self, alpha); + scaledSelf = AtenMulTensorOp::create(rewriter, loc, resType, self, alpha); } else { - scaledSelf = rewriter.create(loc, resType, self, alpha); + scaledSelf = AtenMulScalarOp::create(rewriter, loc, resType, self, alpha); } Value negativeOutput = - rewriter.create(loc, resType, zeroTensor, scaledSelf); - Value rreluOutput = rewriter.create( - loc, resType, positiveOutput, negativeOutput, constantOneFloat); + AtenMinimumOp::create(rewriter, loc, resType, zeroTensor, scaledSelf); + Value rreluOutput = + AtenAddTensorOp::create(rewriter, loc, resType, positiveOutput, + negativeOutput, constantOneFloat); rewriter.replaceOp(op, rreluOutput); return success(); } @@ -4250,14 +4269,13 @@ class DecomposeAtenRreluWithNoiseOp Value lower = op.getLower(); Value upper = op.getUpper(); auto resType = cast(op.getType()); - Value cstNone = rewriter.create(loc); + Value cstNone = ConstantNoneOp::create(rewriter, loc); Value cstFalse = - rewriter.create(loc, rewriter.getBoolAttr(false)); - Value result = - 
rewriter - .create( - loc, resType, self, noise, lower, upper, cstFalse, cstNone) - ->getResult(0); + ConstantBoolOp::create(rewriter, loc, rewriter.getBoolAttr(false)); + Value result = AtenRreluWithNoiseFunctionalOp::create( + rewriter, loc, resType, self, noise, lower, upper, + cstFalse, cstNone) + ->getResult(0); rewriter.replaceOp(op, result); return success(); } @@ -4287,53 +4305,54 @@ class DecomposeAtenRreluWithNoiseFunctionalOp } Value constantZeroFloat = - rewriter.create(loc, rewriter.getF64FloatAttr(0.0)); + ConstantFloatOp::create(rewriter, loc, rewriter.getF64FloatAttr(0.0)); Value constantOneFloat = - rewriter.create(loc, rewriter.getF64FloatAttr(1.0)); + ConstantFloatOp::create(rewriter, loc, rewriter.getF64FloatAttr(1.0)); Value constantTwoFloat = - rewriter.create(loc, rewriter.getF64FloatAttr(2.0)); + ConstantFloatOp::create(rewriter, loc, rewriter.getF64FloatAttr(2.0)); Value alpha; if (training) { - Value none = rewriter.create(loc); - Value emptyTensor = rewriter.create( - loc, resType, self, constantZeroFloat, /*dtype=*/none, + Value none = ConstantNoneOp::create(rewriter, loc); + Value emptyTensor = AtenFullLikeOp::create( + rewriter, loc, resType, self, constantZeroFloat, /*dtype=*/none, /*layout=*/none, /*device=*/none, /*pin_memoty=*/none, /*memory_format=*/none); - alpha = rewriter.create(loc, resType, emptyTensor, - /*from=*/lower, /*to=*/upper, - /*generator=*/none); + alpha = AtenUniformOp::create(rewriter, loc, resType, emptyTensor, + /*from=*/lower, /*to=*/upper, + /*generator=*/none); } else { - Value half = rewriter.create(loc, constantTwoFloat.getType(), - lower, upper); - alpha = rewriter.create(loc, constantTwoFloat.getType(), half, - constantTwoFloat); + Value half = AtenAddOp::create(rewriter, loc, constantTwoFloat.getType(), + lower, upper); + alpha = AtenDivOp::create(rewriter, loc, constantTwoFloat.getType(), half, + constantTwoFloat); } Value zeroTensor = createRank0Tensor(rewriter, loc, resType, constantZeroFloat); Value 
positiveOutput = - rewriter.create(loc, resType, zeroTensor, self); + AtenMaximumOp::create(rewriter, loc, resType, zeroTensor, self); Value scaledSelf; if (training) { - scaledSelf = rewriter.create(loc, resType, self, alpha); + scaledSelf = AtenMulTensorOp::create(rewriter, loc, resType, self, alpha); auto boolResType = resType.getWithSizesAndDtype(resType.getSizes(), rewriter.getI1Type()); Value oneTensor = createRank0Tensor(rewriter, loc, resType, constantOneFloat); - Value not_positive = rewriter.create( - loc, boolResType, self, constantZeroFloat); - noise = rewriter.create(loc, resType, not_positive, - alpha, oneTensor); + Value not_positive = AtenLeScalarOp::create(rewriter, loc, boolResType, + self, constantZeroFloat); + noise = AtenWhereSelfOp::create(rewriter, loc, resType, not_positive, + alpha, oneTensor); } else { - scaledSelf = rewriter.create(loc, resType, self, alpha); + scaledSelf = AtenMulScalarOp::create(rewriter, loc, resType, self, alpha); } Value negativeOutput = - rewriter.create(loc, resType, zeroTensor, scaledSelf); - Value rreluOutput = rewriter.create( - loc, resType, positiveOutput, negativeOutput, constantOneFloat); + AtenMinimumOp::create(rewriter, loc, resType, zeroTensor, scaledSelf); + Value rreluOutput = + AtenAddTensorOp::create(rewriter, loc, resType, positiveOutput, + negativeOutput, constantOneFloat); rewriter.replaceOp(op, {rreluOutput, noise}); return success(); } @@ -4356,27 +4375,27 @@ class DecomposeAtenCeluOp : public OpRewritePattern { } Value constantZero = - rewriter.create(loc, rewriter.getI64IntegerAttr(0)); + ConstantIntOp::create(rewriter, loc, rewriter.getI64IntegerAttr(0)); Value constantOne = - rewriter.create(loc, rewriter.getF64FloatAttr(1.0)); + ConstantFloatOp::create(rewriter, loc, rewriter.getF64FloatAttr(1.0)); // positiveOutput = max(0,x) Value zeroTensor = createRank0Tensor(rewriter, loc, resType, constantZero); Value positiveOutput = - rewriter.create(loc, resType, zeroTensor, input); + 
AtenMaximumOp::create(rewriter, loc, resType, zeroTensor, input); // negativeOutput = min(0,alpha∗(exp(x/alpha)−1)) Value scaledInput = - rewriter.create(loc, resType, input, alpha); - Value expX = rewriter.create(loc, resType, scaledInput); - Value expXM1 = rewriter.create(loc, resType, expX, - constantOne, constantOne); + AtenDivScalarOp::create(rewriter, loc, resType, input, alpha); + Value expX = AtenExpOp::create(rewriter, loc, resType, scaledInput); + Value expXM1 = AtenSubScalarOp::create(rewriter, loc, resType, expX, + constantOne, constantOne); Value scaledExpXM1 = - rewriter.create(loc, resType, expXM1, alpha); + AtenMulScalarOp::create(rewriter, loc, resType, expXM1, alpha); Value negativeOutput = - rewriter.create(loc, resType, zeroTensor, scaledExpXM1); - Value celuOutput = rewriter.create( - loc, resType, positiveOutput, negativeOutput, constantOne); + AtenMinimumOp::create(rewriter, loc, resType, zeroTensor, scaledExpXM1); + Value celuOutput = AtenAddTensorOp::create( + rewriter, loc, resType, positiveOutput, negativeOutput, constantOne); rewriter.replaceOp(op, celuOutput); return success(); @@ -4396,17 +4415,17 @@ class DecomposeAtenLerpScalarOp : public OpRewritePattern { return rewriter.notifyMatchFailure(op, "result should have dtype"); } Value cstOne = - rewriter.create(loc, rewriter.getI64IntegerAttr(1)); + ConstantIntOp::create(rewriter, loc, rewriter.getI64IntegerAttr(1)); auto start = op.getSelf(); auto inputType = cast(start.getType()); - auto delta = rewriter.create(loc, inputType, op.getEnd(), - start, cstOne); + auto delta = AtenSubTensorOp::create(rewriter, loc, inputType, op.getEnd(), + start, cstOne); - auto weightedDelta = - rewriter.create(loc, inputType, delta, op.getWeight()); - auto lerp = rewriter.create(loc, resType, start, - weightedDelta, cstOne); + auto weightedDelta = AtenMulScalarOp::create(rewriter, loc, inputType, + delta, op.getWeight()); + auto lerp = AtenAddTensorOp::create(rewriter, loc, resType, start, + 
weightedDelta, cstOne); rewriter.replaceOp(op, lerp); return success(); } @@ -4425,17 +4444,17 @@ class DecomposeAtenLerpTensorOp : public OpRewritePattern { return rewriter.notifyMatchFailure(op, "result should have dtype"); } Value cstOne = - rewriter.create(loc, rewriter.getI64IntegerAttr(1)); + ConstantIntOp::create(rewriter, loc, rewriter.getI64IntegerAttr(1)); auto start = op.getSelf(); auto inputType = cast(start.getType()); - auto delta = rewriter.create(loc, inputType, op.getEnd(), - start, cstOne); + auto delta = AtenSubTensorOp::create(rewriter, loc, inputType, op.getEnd(), + start, cstOne); - auto weightedDelta = - rewriter.create(loc, inputType, delta, op.getWeight()); - auto lerp = rewriter.create(loc, resType, start, - weightedDelta, cstOne); + auto weightedDelta = AtenMulTensorOp::create(rewriter, loc, inputType, + delta, op.getWeight()); + auto lerp = AtenAddTensorOp::create(rewriter, loc, resType, start, + weightedDelta, cstOne); rewriter.replaceOp(op, lerp); return success(); } @@ -4460,28 +4479,28 @@ class DecomposeAtenEluOp : public OpRewritePattern { } Value constantZero = - rewriter.create(loc, rewriter.getI64IntegerAttr(0)); + ConstantIntOp::create(rewriter, loc, rewriter.getI64IntegerAttr(0)); Value constantOne = - rewriter.create(loc, rewriter.getF64FloatAttr(1.0)); + ConstantFloatOp::create(rewriter, loc, rewriter.getF64FloatAttr(1.0)); Value zeroTensor = createRank0Tensor(rewriter, loc, resType, constantZero); Value maxZeroX = - rewriter.create(loc, resType, zeroTensor, input); + AtenMaximumOp::create(rewriter, loc, resType, zeroTensor, input); Value positiveOutput = - rewriter.create(loc, resType, maxZeroX, scale); + AtenMulScalarOp::create(rewriter, loc, resType, maxZeroX, scale); Value minZeroX = - rewriter.create(loc, resType, zeroTensor, input); + AtenMinimumOp::create(rewriter, loc, resType, zeroTensor, input); Value scaledMinZeroX = - rewriter.create(loc, resType, minZeroX, inputScale); - Value expX = rewriter.create(loc, resType, 
scaledMinZeroX); - Value expXM1 = rewriter.create(loc, resType, expX, - constantOne, constantOne); + AtenMulScalarOp::create(rewriter, loc, resType, minZeroX, inputScale); + Value expX = AtenExpOp::create(rewriter, loc, resType, scaledMinZeroX); + Value expXM1 = AtenSubScalarOp::create(rewriter, loc, resType, expX, + constantOne, constantOne); Value scaledExpXM1 = - rewriter.create(loc, resType, expXM1, scale); + AtenMulScalarOp::create(rewriter, loc, resType, expXM1, scale); Value negativeOutput = - rewriter.create(loc, resType, scaledExpXM1, alpha); + AtenMulScalarOp::create(rewriter, loc, resType, scaledExpXM1, alpha); - Value eluOutput = rewriter.create( - loc, resType, positiveOutput, negativeOutput, constantOne); + Value eluOutput = AtenAddTensorOp::create( + rewriter, loc, resType, positiveOutput, negativeOutput, constantOne); rewriter.replaceOp(op, eluOutput); return success(); @@ -4508,34 +4527,34 @@ class DecomposeAtenSeluOp : public OpRewritePattern { double alpha = 1.6732632423543772848170429916717; // Create constants for λ and α - Value scaleVal = rewriter.create( - loc, rewriter.getF64FloatAttr(scale)); - Value alphaVal = rewriter.create( - loc, rewriter.getF64FloatAttr(alpha)); + Value scaleVal = Torch::ConstantFloatOp::create( + rewriter, loc, rewriter.getF64FloatAttr(scale)); + Value alphaVal = Torch::ConstantFloatOp::create( + rewriter, loc, rewriter.getF64FloatAttr(alpha)); // Create zero tensor for comparison Value constantZero = - rewriter.create(loc, rewriter.getI64IntegerAttr(0)); + ConstantIntOp::create(rewriter, loc, rewriter.getI64IntegerAttr(0)); Value zeroTensor = createRank0Tensor(rewriter, loc, resType, constantZero); // Calculate positive and negative parts Value constantOne = - rewriter.create(loc, rewriter.getF64FloatAttr(1.0)); + ConstantFloatOp::create(rewriter, loc, rewriter.getF64FloatAttr(1.0)); Value positiveOutput = - rewriter.create(loc, resType, zeroTensor, input); + AtenMaximumOp::create(rewriter, loc, resType, 
zeroTensor, input); Value minZeroX = - rewriter.create(loc, resType, zeroTensor, input); - Value expInput = rewriter.create(loc, resType, minZeroX); - Value expInputMinusOne = rewriter.create( - loc, resType, expInput, constantOne, constantOne); - Value negativeOutput = rewriter.create( - loc, resType, expInputMinusOne, alphaVal); + AtenMinimumOp::create(rewriter, loc, resType, zeroTensor, input); + Value expInput = AtenExpOp::create(rewriter, loc, resType, minZeroX); + Value expInputMinusOne = AtenSubScalarOp::create( + rewriter, loc, resType, expInput, constantOne, constantOne); + Value negativeOutput = AtenMulScalarOp::create(rewriter, loc, resType, + expInputMinusOne, alphaVal); // Multiply the result by λ - Value seluOutput = rewriter.create( - loc, resType, positiveOutput, negativeOutput, constantOne); + Value seluOutput = AtenAddTensorOp::create( + rewriter, loc, resType, positiveOutput, negativeOutput, constantOne); seluOutput = - rewriter.create(loc, resType, seluOutput, scaleVal); + AtenMulScalarOp::create(rewriter, loc, resType, seluOutput, scaleVal); // Replace the original operation rewriter.replaceOp(op, seluOutput); @@ -4565,9 +4584,9 @@ class DecomposeAtenTOp : public OpRewritePattern { rewriter.replaceOp(op, lhs); else { Value zero = - rewriter.create(loc, rewriter.getI64IntegerAttr(0)); + ConstantIntOp::create(rewriter, loc, rewriter.getI64IntegerAttr(0)); Value one = - rewriter.create(loc, rewriter.getI64IntegerAttr(1)); + ConstantIntOp::create(rewriter, loc, rewriter.getI64IntegerAttr(1)); rewriter.replaceOpWithNewOp(op, op.getType(), lhs, zero, one); } @@ -4612,8 +4631,8 @@ class DecomposeAtenStackOp : public OpRewritePattern { .getWithSizesAndDtype( /*optionalSizes=*/std::nullopt, /*optionalDtype=*/nullptr); Type listType = Torch::ListType::get(listElemType); - Value unsqueezedTensorList = rewriter.create( - op.getLoc(), listType, unsqueezedTensors); + Value unsqueezedTensorList = PrimListConstructOp::create( + rewriter, op.getLoc(), listType, 
unsqueezedTensors); rewriter.replaceOpWithNewOp(op, op.getType(), unsqueezedTensorList, op.getDim()); return success(); @@ -4645,7 +4664,7 @@ class DecomposeAtenHstackOp : public OpRewritePattern { // Check if the tensor is already of rank >= 1. if (*tensorRank < 1) { auto atleast1dTensor = - rewriter.create(loc, tensor.getType(), tensor); + AtenAtleast1dOp::create(rewriter, loc, tensor.getType(), tensor); atleast1dTensors.push_back(atleast1dTensor); } else { atleast1dTensors.push_back(tensor); @@ -4655,18 +4674,18 @@ class DecomposeAtenHstackOp : public OpRewritePattern { // Make Value list from atleast1dTensors variable. auto elemType = cast(atleast1dTensors[0].getType()) .getWithSizesAndDtype(std::nullopt, nullptr); - Value atleast1dTensorList = rewriter.create( - loc, Torch::ListType::get(elemType), atleast1dTensors); + Value atleast1dTensorList = PrimListConstructOp::create( + rewriter, loc, Torch::ListType::get(elemType), atleast1dTensors); // Replace hstack with cat operator. if (getTensorRank(atleast1dTensors[0]) == 1) rewriter.replaceOpWithNewOp( op, op.getType(), atleast1dTensorList, - rewriter.create(loc, rewriter.getI64IntegerAttr(0))); + ConstantIntOp::create(rewriter, loc, rewriter.getI64IntegerAttr(0))); else rewriter.replaceOpWithNewOp( op, op.getType(), atleast1dTensorList, - rewriter.create(loc, rewriter.getI64IntegerAttr(1))); + ConstantIntOp::create(rewriter, loc, rewriter.getI64IntegerAttr(1))); return success(); } @@ -4707,14 +4726,14 @@ class DecomposeAtenColumnStackOp : public OpRewritePattern { auto newTy = tTy.getWithSizesAndDtype(tSizes, tTy.getDtype()); SmallVector newShapeList; for (auto tSize : tSizes) { - newShapeList.push_back(rewriter.create( - loc, rewriter.getI64IntegerAttr(tSize))); + newShapeList.push_back(ConstantIntOp::create( + rewriter, loc, rewriter.getI64IntegerAttr(tSize))); } - auto newShape = rewriter.create( - loc, Torch::ListType::get(rewriter.getType()), + auto newShape = PrimListConstructOp::create( + rewriter, loc, 
Torch::ListType::get(rewriter.getType()), newShapeList); Value tensor2d = - rewriter.create(loc, newTy, tensor, newShape); + AtenReshapeOp::create(rewriter, loc, newTy, tensor, newShape); tensors2d.push_back(tensor2d); } else { tensors2d.push_back(tensor); @@ -4723,12 +4742,12 @@ class DecomposeAtenColumnStackOp : public OpRewritePattern { auto elemType = cast(tensors2d[0].getType()) .getWithSizesAndDtype(std::nullopt, nullptr); - Value newTensors = rewriter.create( - loc, Torch::ListType::get(elemType), tensors2d); + Value newTensors = PrimListConstructOp::create( + rewriter, loc, Torch::ListType::get(elemType), tensors2d); rewriter.replaceOpWithNewOp( op, op.getType(), newTensors, - rewriter.create(loc, rewriter.getI64IntegerAttr(1))); + ConstantIntOp::create(rewriter, loc, rewriter.getI64IntegerAttr(1))); return success(); } @@ -4756,11 +4775,11 @@ class DecomposeAtenRollOp : public OpRewritePattern { return op.emitError("list sizes of shifts and dims are not the same"); auto loc = op.getLoc(); - Value constNone = rewriter.create(loc); - Value constZero = rewriter.create( - loc, rewriter.getI64IntegerAttr(0)); - Value constOne = rewriter.create( - loc, rewriter.getI64IntegerAttr(1)); + Value constNone = ConstantNoneOp::create(rewriter, loc); + Value constZero = Torch::ConstantIntOp::create( + rewriter, loc, rewriter.getI64IntegerAttr(0)); + Value constOne = Torch::ConstantIntOp::create( + rewriter, loc, rewriter.getI64IntegerAttr(1)); auto self = op.getSelf(); auto selfTy = cast(self.getType()); // roll(input, shift, dim) = cat({ @@ -4768,22 +4787,22 @@ class DecomposeAtenRollOp : public OpRewritePattern { // slice(input, dim, 0, -shift)}, dim) auto imitateRoll = [&](Value input, Value shift, Value dim, int64_t cstDim) { - Value negShift = rewriter.create(loc, shift); + Value negShift = AtenNegIntOp::create(rewriter, loc, shift); ArrayRef inputShape = selfTy.getSizes(); SmallVector sizes; sizes.append(inputShape.begin(), inputShape.end()); sizes[cstDim] = 
kUnknownSize; Type sliceTy = selfTy.getWithSizesAndDtype(llvm::ArrayRef(sizes), selfTy.getOptionalDtype()); - Value slice0 = rewriter.create( - loc, sliceTy, input, dim, negShift, constNone, constOne); - Value slice1 = rewriter.create( - loc, sliceTy, input, dim, constZero, negShift, constOne); + Value slice0 = AtenSliceTensorOp::create( + rewriter, loc, sliceTy, input, dim, negShift, constNone, constOne); + Value slice1 = AtenSliceTensorOp::create( + rewriter, loc, sliceTy, input, dim, constZero, negShift, constOne); Type listType = Torch::ListType::get(sliceTy); - Value slices = rewriter.create( - loc, listType, llvm::ArrayRef{slice0, slice1}); - return rewriter.create(loc, self.getType(), slices, dim); + Value slices = PrimListConstructOp::create( + rewriter, loc, listType, llvm::ArrayRef{slice0, slice1}); + return AtenCatOp::create(rewriter, loc, self.getType(), slices, dim); }; std::optional maybeRank = getTensorRank(self); if (!maybeRank) @@ -4854,7 +4873,7 @@ class DecomposeAtenRepeatOp : public OpRewritePattern { llvm::SmallVector unsqueezeDims; for (int i = 0; i < batch; ++i) { Value iv = - rewriter.create(loc, rewriter.getI64IntegerAttr(i)); + ConstantIntOp::create(rewriter, loc, rewriter.getI64IntegerAttr(i)); self = *unsqueezeTensor(rewriter, op, self, iv); selfTy = cast(self.getType()); unsqueezeDims.push_back(i); @@ -4866,7 +4885,7 @@ class DecomposeAtenRepeatOp : public OpRewritePattern { continue; int64_t dim = i + unsqueezeDims.size() - batch; Value iv = - rewriter.create(loc, rewriter.getI64IntegerAttr(dim)); + ConstantIntOp::create(rewriter, loc, rewriter.getI64IntegerAttr(dim)); self = *unsqueezeTensor(rewriter, op, self, iv); selfTy = cast(self.getType()); unsqueezeDims.push_back(dim); @@ -4888,18 +4907,18 @@ class DecomposeAtenRepeatOp : public OpRewritePattern { int dim = lengths.size(); Value iv = - rewriter.create(loc, rewriter.getI64IntegerAttr(dim)); - Value dimV = rewriter.create(loc, self, /*dim=*/iv); + ConstantIntOp::create(rewriter, 
loc, rewriter.getI64IntegerAttr(dim)); + Value dimV = AtenSizeIntOp::create(rewriter, loc, self, /*dim=*/iv); lengths.push_back(dimV); expandShape.push_back(selfTy.getSizes()[dim]); } // Materialize the broadcast: - Value lengthv = rewriter.create( - loc, ListType::get(rewriter.getType()), lengths); + Value lengthv = PrimListConstructOp::create( + rewriter, loc, ListType::get(rewriter.getType()), lengths); selfTy = rewriter.getType(expandShape, selfTy.getOptionalDtype()); - self = rewriter.create(loc, selfTy, self, lengthv); + self = AtenBroadcastToOp::create(rewriter, loc, selfTy, self, lengthv); auto outShape = cast(op.getResult().getType()).getSizes(); for (int i = batch, s = repeats.size(); i < s; ++i) { @@ -4917,12 +4936,12 @@ class DecomposeAtenRepeatOp : public OpRewritePattern { selfTy = rewriter.getType(flattenShape, selfTy.getOptionalDtype()); Value start = - rewriter.create(loc, rewriter.getI64IntegerAttr(i)); - Value end = rewriter.create( - loc, rewriter.getI64IntegerAttr(i + 1)); + ConstantIntOp::create(rewriter, loc, rewriter.getI64IntegerAttr(i)); + Value end = ConstantIntOp::create(rewriter, loc, + rewriter.getI64IntegerAttr(i + 1)); - self = rewriter.create(loc, selfTy, self, start, - end); + self = AtenFlattenUsingIntsOp::create(rewriter, loc, selfTy, self, start, + end); } rewriter.replaceOp(op, self); @@ -4973,9 +4992,9 @@ class DecomposeAtenRepeatInterleaveSelfIntOp } dimValue = - rewriter.create(loc, rewriter.getI64IntegerAttr(dim)); - Value dimValuePlusOne = rewriter.create( - loc, rewriter.getI64IntegerAttr(dim + 1)); + ConstantIntOp::create(rewriter, loc, rewriter.getI64IntegerAttr(dim)); + Value dimValuePlusOne = ConstantIntOp::create( + rewriter, loc, rewriter.getI64IntegerAttr(dim + 1)); auto unsqueezedInfo = unsqueezeTensor(rewriter, op, self, dimValuePlusOne); if (failed(unsqueezedInfo)) @@ -4984,14 +5003,15 @@ class DecomposeAtenRepeatInterleaveSelfIntOp self = *unsqueezedInfo; Value constMinusOne = - rewriter.create(loc, 
rewriter.getI64IntegerAttr(-1)); + ConstantIntOp::create(rewriter, loc, rewriter.getI64IntegerAttr(-1)); SmallVector expandShapeValueList(inputRank + 1, constMinusOne); - expandShapeValueList[dim + 1] = rewriter.create( - loc, rewriter.getI64IntegerAttr(repeats)); - Value expandShapeList = rewriter.create( - loc, ListType::get(IntType::get(context)), expandShapeValueList); + expandShapeValueList[dim + 1] = ConstantIntOp::create( + rewriter, loc, rewriter.getI64IntegerAttr(repeats)); + Value expandShapeList = PrimListConstructOp::create( + rewriter, loc, ListType::get(IntType::get(context)), + expandShapeValueList); Value constFalse = - rewriter.create(loc, rewriter.getBoolAttr(false)); + ConstantBoolOp::create(rewriter, loc, rewriter.getBoolAttr(false)); SmallVector expandShape(inputRank + 1); for (int64_t i = 0; i <= dim; i++) { @@ -5005,18 +5025,18 @@ class DecomposeAtenRepeatInterleaveSelfIntOp BaseTensorType expandTy = rewriter.getType( expandShape, selfTy.getOptionalDtype()); - Value expandSelf = rewriter.create( - loc, expandTy, self, expandShapeList, constFalse); + Value expandSelf = AtenExpandOp::create(rewriter, loc, expandTy, self, + expandShapeList, constFalse); Value result; if (dimIsNone) { Value constZero = - rewriter.create(loc, rewriter.getI64IntegerAttr(0)); - result = rewriter.create( - loc, resType, expandSelf, constZero, constMinusOne); + ConstantIntOp::create(rewriter, loc, rewriter.getI64IntegerAttr(0)); + result = AtenFlattenUsingIntsOp::create( + rewriter, loc, resType, expandSelf, constZero, constMinusOne); } else { - result = rewriter.create(loc, resType, expandSelf, - dimValue, dimValuePlusOne); + result = PrimsCollapseOp::create(rewriter, loc, resType, expandSelf, + dimValue, dimValuePlusOne); } rewriter.replaceOp(op, result); @@ -5051,7 +5071,7 @@ class DecomposeAtenFlattenUsingIntsOp SmallVector newSizes; if (rank == 0) { Value one = - rewriter.create(loc, rewriter.getI64IntegerAttr(1)); + ConstantIntOp::create(rewriter, loc, 
rewriter.getI64IntegerAttr(1)); newSizes.push_back(one); } else { start = toPositiveDim(start, rank); @@ -5065,22 +5085,22 @@ class DecomposeAtenFlattenUsingIntsOp newSizes.reserve(rank - end + start); for (int64_t k = 0; k < start; ++k) { Value dim = - rewriter.create(loc, rewriter.getI64IntegerAttr(k)); + ConstantIntOp::create(rewriter, loc, rewriter.getI64IntegerAttr(k)); newSizes.push_back( - rewriter.create(loc, self, /*dim=*/dim)); + AtenSizeIntOp::create(rewriter, loc, self, /*dim=*/dim)); } Value flattenDimSize = - rewriter.create(loc, rewriter.getI64IntegerAttr(-1)); + ConstantIntOp::create(rewriter, loc, rewriter.getI64IntegerAttr(-1)); newSizes.push_back(flattenDimSize); for (int64_t k = end + 1; k < rank; ++k) { Value dim = - rewriter.create(loc, rewriter.getI64IntegerAttr(k)); + ConstantIntOp::create(rewriter, loc, rewriter.getI64IntegerAttr(k)); newSizes.push_back( - rewriter.create(loc, self, /*dim=*/dim)); + AtenSizeIntOp::create(rewriter, loc, self, /*dim=*/dim)); } } - Value newSizeList = rewriter.create( - loc, ListType::get(IntType::get(context)), newSizes); + Value newSizeList = PrimListConstructOp::create( + rewriter, loc, ListType::get(IntType::get(context)), newSizes); rewriter.replaceOpWithNewOp(op, op.getType(), op.getSelf(), newSizeList); return success(); @@ -5148,9 +5168,9 @@ class DecomposeAtenUnflattenIntOp SmallVector newSizes; for (int64_t i = 0; i < inputRank; ++i) { Value dimValue = - rewriter.create(loc, rewriter.getI64IntegerAttr(i)); + ConstantIntOp::create(rewriter, loc, rewriter.getI64IntegerAttr(i)); Value dimSize = - rewriter.create(loc, self, /*dim=*/dimValue); + AtenSizeIntOp::create(rewriter, loc, self, /*dim=*/dimValue); if (i == dimInt) { int64_t inferredSizeInt = inputShape[i]; int64_t inferredDim; @@ -5159,15 +5179,15 @@ class DecomposeAtenUnflattenIntOp inferred = true; inferredDim = j; } else { - Value sizeValue = rewriter.create( - loc, rewriter.getI64IntegerAttr(sizesInts[j])); + Value sizeValue = 
ConstantIntOp::create( + rewriter, loc, rewriter.getI64IntegerAttr(sizesInts[j])); newSizes.push_back(sizeValue); inferredSizeInt = inferredSizeInt / sizesInts[j]; } } if (inferred) { - Value inferredSize = rewriter.create( - loc, rewriter.getI64IntegerAttr(inferredSizeInt)); + Value inferredSize = ConstantIntOp::create( + rewriter, loc, rewriter.getI64IntegerAttr(inferredSizeInt)); newSizes.insert(newSizes.begin() + inferredDim + i, inferredSize); } } else { @@ -5176,8 +5196,8 @@ class DecomposeAtenUnflattenIntOp } // Create the AtenViewOp to replace the original op. - Value newSizeList = rewriter.create( - loc, ListType::get(IntType::get(context)), newSizes); + Value newSizeList = PrimListConstructOp::create( + rewriter, loc, ListType::get(IntType::get(context)), newSizes); rewriter.replaceOpWithNewOp(op, op.getType(), op.getSelf(), newSizeList); return success(); @@ -5196,11 +5216,11 @@ class DecomposeAtenUpsampleNearestVecOp Value scales = op.getScaleFactors(); static_assert(std::is_same_v || std::is_same_v); - Value cstMode = rewriter.create( - op.getLoc(), rewriter.getStringAttr("nearest")); - Value cstNone = rewriter.create(op.getLoc()); + Value cstMode = Torch::ConstantStrOp::create( + rewriter, op.getLoc(), rewriter.getStringAttr("nearest")); + Value cstNone = Torch::ConstantNoneOp::create(rewriter, op.getLoc()); Value cstAntialias = - rewriter.create(op.getLoc(), false); + Torch::ConstantBoolOp::create(rewriter, op.getLoc(), false); rewriter.replaceOpWithNewOp( op, op.getType(), op.getInput(), op.getOutputSize(), op.getScaleFactors(), cstMode, cstNone, cstNone, cstAntialias); @@ -5315,34 +5335,36 @@ class DecomposeAtenNanToNumOp : public OpRewritePattern { if (isa(nan.getType())) { nan = - rewriter.create(loc, rewriter.getF64FloatAttr(0.0)); + ConstantFloatOp::create(rewriter, loc, rewriter.getF64FloatAttr(0.0)); } if (isa(posinf.getType())) { - posinf = rewriter.create( - loc, rewriter.getF64FloatAttr( - 
APFloat::getLargest(outputElementType.getFloatSemantics()) - .convertToDouble())); + posinf = ConstantFloatOp::create( + rewriter, loc, + rewriter.getF64FloatAttr( + APFloat::getLargest(outputElementType.getFloatSemantics()) + .convertToDouble())); } if (isa(neginf.getType())) { - neginf = rewriter.create( - loc, rewriter.getF64FloatAttr( - APFloat::getLargest(outputElementType.getFloatSemantics(), - /*Negative=*/true) - .convertToDouble())); + neginf = ConstantFloatOp::create( + rewriter, loc, + rewriter.getF64FloatAttr( + APFloat::getLargest(outputElementType.getFloatSemantics(), + /*Negative=*/true) + .convertToDouble())); } auto compareType = outputType.getWithSizesAndDtype( outputType.getOptionalSizes(), rewriter.getI1Type()); Value isNan = - rewriter.create(loc, compareType, op.getSelf()); - Value where = rewriter.create( - loc, outputType, isNan, nan, op.getSelf()); + Torch::AtenIsnanOp::create(rewriter, loc, compareType, op.getSelf()); + Value where = Torch::AtenWhereScalarSelfOp::create( + rewriter, loc, outputType, isNan, nan, op.getSelf()); Value isposinf = - rewriter.create(loc, compareType, where); - where = rewriter.create( - loc, outputType, isposinf, posinf, where); + Torch::AtenIsposinfOp::create(rewriter, loc, compareType, where); + where = Torch::AtenWhereScalarSelfOp::create(rewriter, loc, outputType, + isposinf, posinf, where); Value isneginf = - rewriter.create(loc, compareType, where); + Torch::AtenIsneginfOp::create(rewriter, loc, compareType, where); rewriter.replaceOpWithNewOp( op, outputType, isneginf, neginf, where); return success(); @@ -5433,15 +5455,15 @@ class DecomposeAtenMaskedScatterOp int64_t selfRank = selfTy.getSizes().size(); int64_t sourceRank = sourceTy.getSizes().size(); - Value constZero = rewriter.create( - loc, rewriter.getI64IntegerAttr(0)); - Value constOne = rewriter.create( - loc, rewriter.getI64IntegerAttr(1)); - Value constNone = rewriter.create(loc); - Value selfLastDim = rewriter.create( - loc, 
rewriter.getI64IntegerAttr(selfRank - 1)); - Value sourceLastDim = rewriter.create( - loc, rewriter.getI64IntegerAttr(sourceRank - 1)); + Value constZero = Torch::ConstantIntOp::create( + rewriter, loc, rewriter.getI64IntegerAttr(0)); + Value constOne = Torch::ConstantIntOp::create( + rewriter, loc, rewriter.getI64IntegerAttr(1)); + Value constNone = ConstantNoneOp::create(rewriter, loc); + Value selfLastDim = Torch::ConstantIntOp::create( + rewriter, loc, rewriter.getI64IntegerAttr(selfRank - 1)); + Value sourceLastDim = Torch::ConstantIntOp::create( + rewriter, loc, rewriter.getI64IntegerAttr(sourceRank - 1)); auto si64Type = IntegerType::get(context, 64, IntegerType::Signed); auto int64Dtype = getDtypeIntValueForType( @@ -5449,40 +5471,40 @@ class DecomposeAtenMaskedScatterOp rewriter.getIntegerType(/*width=*/64, /*isSigned=*/true)); auto selfIntType = selfTy.getWithSizesAndDtype(selfTy.getSizes(), si64Type); - Value zerosLike = rewriter.create( - loc, selfIntType, self, int64Dtype, constNone, constNone, constNone, - constNone); - Value maskInt = rewriter.create( - loc, selfIntType, mask, zerosLike, constOne); + Value zerosLike = Torch::AtenZerosLikeOp::create( + rewriter, loc, selfIntType, self, int64Dtype, constNone, constNone, + constNone, constNone); + Value maskInt = Torch::AtenAddTensorOp::create(rewriter, loc, selfIntType, + mask, zerosLike, constOne); auto flattenMaskedType = selfTy.getWithSizesAndDtype( /*optionalSizes=*/{selfNumel}, si64Type); - Value maskIntFlatten = rewriter.create( - loc, flattenMaskedType, maskInt, constZero, selfLastDim); - Value prefixSum = rewriter.create( - loc, flattenMaskedType, maskIntFlatten, + Value maskIntFlatten = Torch::AtenFlattenUsingIntsOp::create( + rewriter, loc, flattenMaskedType, maskInt, constZero, selfLastDim); + Value prefixSum = Torch::AtenCumsumOp::create( + rewriter, loc, flattenMaskedType, maskIntFlatten, /*dim=*/constZero, constNone); - Value prefixSumMinusOne = rewriter.create( - loc, flattenMaskedType, 
prefixSum, constOne, constOne); - Value maskPrefix = rewriter.create( - loc, flattenMaskedType, prefixSumMinusOne, /*min=*/constZero, + Value prefixSumMinusOne = Torch::AtenSubScalarOp::create( + rewriter, loc, flattenMaskedType, prefixSum, constOne, constOne); + Value maskPrefix = Torch::AtenClampOp::create( + rewriter, loc, flattenMaskedType, prefixSumMinusOne, /*min=*/constZero, /*max=*/constNone); auto sourceFlattenType = sourceTy.getWithSizesAndDtype( /*optionalSizes=*/{sourceNumel}, sourceTy.getDtype()); - Value sourceFlatten = rewriter.create( - loc, sourceFlattenType, source, constZero, sourceLastDim); + Value sourceFlatten = Torch::AtenFlattenUsingIntsOp::create( + rewriter, loc, sourceFlattenType, source, constZero, sourceLastDim); auto selectSourceType = sourceTy.getWithSizesAndDtype( /*optionalSizes=*/{selfNumel}, sourceTy.getDtype()); - Value selectSource = rewriter.create( - loc, selectSourceType, sourceFlatten, constZero, maskPrefix); + Value selectSource = Torch::AtenIndexSelectOp::create( + rewriter, loc, selectSourceType, sourceFlatten, constZero, maskPrefix); // Reshape normalized output back to the original input shape - auto selfShape = rewriter.create( - loc, Torch::ListType::get(IntType::get(context)), self); - Value sourceReshape = rewriter.create( - loc, selfTy, selectSource, selfShape); + auto selfShape = AtenSizeOp::create( + rewriter, loc, Torch::ListType::get(IntType::get(context)), self); + Value sourceReshape = Torch::AtenViewOp::create(rewriter, loc, selfTy, + selectSource, selfShape); rewriter.replaceOpWithNewOp(op, resTy, mask, sourceReshape, self); return success(); @@ -5521,12 +5543,12 @@ static LogicalResult createTorchTransposeOpForConvTbc(PatternRewriter &rewriter, if (failed(getTransposedType(cast(input.getType()), dimA, dimB, transposedType))) return failure(); - Value cstDimA = rewriter.create( - loc, rewriter.getI64IntegerAttr(dimA)); - Value cstDimB = rewriter.create( - loc, rewriter.getI64IntegerAttr(dimB)); - transposed 
= rewriter.create( - loc, transposedType, input, cstDimA, cstDimB); + Value cstDimA = Torch::ConstantIntOp::create( + rewriter, loc, rewriter.getI64IntegerAttr(dimA)); + Value cstDimB = Torch::ConstantIntOp::create( + rewriter, loc, rewriter.getI64IntegerAttr(dimB)); + transposed = Torch::AtenTransposeIntOp::create(rewriter, loc, transposedType, + input, cstDimA, cstDimB); return success(); } @@ -5535,19 +5557,23 @@ class DecomposeAtenConvTbcOp : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(AtenConvTbcOp op, PatternRewriter &rewriter) const override { - Value emptyList = rewriter.create( - op.getLoc(), Torch::ListType::get(Torch::IntType::get(op.getContext())), + Value emptyList = PrimListConstructOp::create( + rewriter, op.getLoc(), + Torch::ListType::get(Torch::IntType::get(op.getContext())), SmallVector()); - Value cstFalse = rewriter.create(op.getLoc(), false); - Value oneList = rewriter.create( - op.getLoc(), Torch::ListType::get(Torch::IntType::get(op.getContext())), - SmallVector{rewriter.create( - op.getLoc(), rewriter.getI64IntegerAttr(1))}); - Value padding = rewriter.create( - op.getLoc(), Torch::ListType::get(Torch::IntType::get(op.getContext())), + Value cstFalse = + Torch::ConstantBoolOp::create(rewriter, op.getLoc(), false); + Value oneList = PrimListConstructOp::create( + rewriter, op.getLoc(), + Torch::ListType::get(Torch::IntType::get(op.getContext())), + SmallVector{Torch::ConstantIntOp::create( + rewriter, op.getLoc(), rewriter.getI64IntegerAttr(1))}); + Value padding = PrimListConstructOp::create( + rewriter, op.getLoc(), + Torch::ListType::get(Torch::IntType::get(op.getContext())), SmallVector{op.getPad()}); - Value groups = rewriter.create( - op.getLoc(), rewriter.getI64IntegerAttr(1)); + Value groups = Torch::ConstantIntOp::create(rewriter, op.getLoc(), + rewriter.getI64IntegerAttr(1)); // convtbc has WNC layout for input and output // and WCF layout for weight @@ -5572,8 +5598,9 @@ class 
DecomposeAtenConvTbcOp : public OpRewritePattern { return rewriter.notifyMatchFailure(op, "failed to transpose weight to Fcw"); - Value outputNcw = rewriter.create( - op.getLoc(), op->getResultTypes(), selfNcw, weightFcw, op.getBias(), + Value outputNcw = AtenConvolutionOp::create( + rewriter, op.getLoc(), op->getResultTypes(), selfNcw, weightFcw, + op.getBias(), /*stride*/ oneList, /*padding*/ padding, /*dilation*/ oneList, /*transpose*/ cstFalse, /*output_padding*/ emptyList, groups); @@ -5604,10 +5631,12 @@ class DecomposeAtenConv1dOp : public OpRewritePattern { LogicalResult matchAndRewrite(AtenConv1dOp op, PatternRewriter &rewriter) const override { - Value emptyList = rewriter.create( - op.getLoc(), Torch::ListType::get(Torch::IntType::get(op.getContext())), + Value emptyList = PrimListConstructOp::create( + rewriter, op.getLoc(), + Torch::ListType::get(Torch::IntType::get(op.getContext())), SmallVector()); - Value cstFalse = rewriter.create(op.getLoc(), false); + Value cstFalse = + Torch::ConstantBoolOp::create(rewriter, op.getLoc(), false); rewriter.replaceOpWithNewOp( op, op->getResultTypes(), op.getInput(), op.getWeight(), op.getBias(), op.getStride(), op.getPadding(), op.getDilation(), cstFalse, emptyList, @@ -5626,10 +5655,12 @@ class DecomposeAtenConv2dOp : public OpRewritePattern { LogicalResult matchAndRewrite(AtenConv2dOp op, PatternRewriter &rewriter) const override { - Value emptyList = rewriter.create( - op.getLoc(), Torch::ListType::get(Torch::IntType::get(op.getContext())), + Value emptyList = PrimListConstructOp::create( + rewriter, op.getLoc(), + Torch::ListType::get(Torch::IntType::get(op.getContext())), SmallVector()); - Value cstFalse = rewriter.create(op.getLoc(), false); + Value cstFalse = + Torch::ConstantBoolOp::create(rewriter, op.getLoc(), false); rewriter.replaceOpWithNewOp( op, op->getResultTypes(), op.getInput(), op.getWeight(), op.getBias(), op.getStride(), op.getPadding(), op.getDilation(), cstFalse, emptyList, @@ -5667,8 
+5698,8 @@ class DecomposeAtenConvPaddingOp : public OpRewritePattern { return rewriter.notifyMatchFailure(op, "padding must be a constant string"); - Value zero = rewriter.create( - loc, rewriter.getI64IntegerAttr(0)); + Value zero = Torch::ConstantIntOp::create(rewriter, loc, + rewriter.getI64IntegerAttr(0)); SmallVector paddingValues; if (padding_str == "valid") { @@ -5682,29 +5713,32 @@ class DecomposeAtenConvPaddingOp : public OpRewritePattern { getListConstructElements(op.getDilation(), dilation); Value one = - rewriter.create(loc, rewriter.getI64IntegerAttr(1)); + ConstantIntOp::create(rewriter, loc, rewriter.getI64IntegerAttr(1)); Value two = - rewriter.create(loc, rewriter.getI64IntegerAttr(2)); + ConstantIntOp::create(rewriter, loc, rewriter.getI64IntegerAttr(2)); for (unsigned iRank = 2; iRank < rank; iRank++) { - Value dim = rewriter.create( - loc, rewriter.getI64IntegerAttr(iRank)); + Value dim = Torch::ConstantIntOp::create( + rewriter, loc, rewriter.getI64IntegerAttr(iRank)); Value kernelSize = - rewriter.create(loc, weight, dim); + Torch::AtenSizeIntOp::create(rewriter, loc, weight, dim); Value kernelSizeMinusOne = - rewriter.create(loc, kernelSize, one); - Value padding = rewriter.create( - loc, dilation[iRank - 2], kernelSizeMinusOne); - padding = rewriter.create(loc, padding, two); + Torch::AtenSubIntOp::create(rewriter, loc, kernelSize, one); + Value padding = Torch::AtenMulIntOp::create( + rewriter, loc, dilation[iRank - 2], kernelSizeMinusOne); + padding = AtenFloordivIntOp::create(rewriter, loc, padding, two); paddingValues.push_back(padding); } } - Value emptyList = rewriter.create( - op.getLoc(), Torch::ListType::get(Torch::IntType::get(op.getContext())), + Value emptyList = PrimListConstructOp::create( + rewriter, op.getLoc(), + Torch::ListType::get(Torch::IntType::get(op.getContext())), SmallVector()); - Value cstFalse = rewriter.create(op.getLoc(), false); - Value padding = rewriter.create( - op.getLoc(), 
Torch::ListType::get(Torch::IntType::get(op.getContext())), + Value cstFalse = + Torch::ConstantBoolOp::create(rewriter, op.getLoc(), false); + Value padding = PrimListConstructOp::create( + rewriter, op.getLoc(), + Torch::ListType::get(Torch::IntType::get(op.getContext())), paddingValues); rewriter.replaceOpWithNewOp( op, op->getResultTypes(), op.getInput(), op.getWeight(), op.getBias(), @@ -5724,10 +5758,12 @@ class DecomposeAtenConv3dOp : public OpRewritePattern { LogicalResult matchAndRewrite(AtenConv3dOp op, PatternRewriter &rewriter) const override { - Value emptyList = rewriter.create( - op.getLoc(), Torch::ListType::get(Torch::IntType::get(op.getContext())), + Value emptyList = PrimListConstructOp::create( + rewriter, op.getLoc(), + Torch::ListType::get(Torch::IntType::get(op.getContext())), SmallVector()); - Value cstFalse = rewriter.create(op.getLoc(), false); + Value cstFalse = + Torch::ConstantBoolOp::create(rewriter, op.getLoc(), false); rewriter.replaceOpWithNewOp( op, op->getResultTypes(), op.getInput(), op.getWeight(), op.getBias(), op.getStride(), op.getPadding(), op.getDilation(), cstFalse, emptyList, @@ -5747,7 +5783,7 @@ class DecomposeAtenConvTranspose1dOp LogicalResult matchAndRewrite(AtenConvTranspose1dOp op, PatternRewriter &rewriter) const override { - Value cstTrue = rewriter.create(op.getLoc(), true); + Value cstTrue = Torch::ConstantBoolOp::create(rewriter, op.getLoc(), true); rewriter.replaceOpWithNewOp( op, op->getResultTypes(), op.getInput(), op.getWeight(), op.getBias(), op.getStride(), op.getPadding(), op.getDilation(), @@ -5766,7 +5802,7 @@ class DecomposeAtenConvTranspose2dOp LogicalResult matchAndRewrite(AtenConvTranspose2dInputOp op, PatternRewriter &rewriter) const override { - Value cstTrue = rewriter.create(op.getLoc(), true); + Value cstTrue = Torch::ConstantBoolOp::create(rewriter, op.getLoc(), true); rewriter.replaceOpWithNewOp( op, op->getResultTypes(), op.getInput(), op.getWeight(), op.getBias(), op.getStride(), 
op.getPadding(), op.getDilation(), @@ -5785,7 +5821,7 @@ class DecomposeAtenConvTranspose3dOp LogicalResult matchAndRewrite(AtenConvTranspose3dInputOp op, PatternRewriter &rewriter) const override { - Value cstTrue = rewriter.create(op.getLoc(), true); + Value cstTrue = Torch::ConstantBoolOp::create(rewriter, op.getLoc(), true); rewriter.replaceOpWithNewOp( op, op->getResultTypes(), op.getInput(), op.getWeight(), op.getBias(), op.getStride(), op.getPadding(), op.getDilation(), @@ -5871,15 +5907,15 @@ class DecomposeAtenConvolutionBackwardOp return rewriter.notifyMatchFailure( op, "unimplemented: only 2D convolutions supported."); - Value cstZero = rewriter.create( - loc, rewriter.getI64IntegerAttr(0)); - Value cstOne = rewriter.create( - loc, rewriter.getI64IntegerAttr(1)); - Value cstTwo = rewriter.create( - loc, rewriter.getI64IntegerAttr(2)); - Value cstNone = rewriter.create(loc); - Value cstFalse = rewriter.create( - loc, rewriter.getBoolAttr(false)); + Value cstZero = Torch::ConstantIntOp::create(rewriter, loc, + rewriter.getI64IntegerAttr(0)); + Value cstOne = Torch::ConstantIntOp::create(rewriter, loc, + rewriter.getI64IntegerAttr(1)); + Value cstTwo = Torch::ConstantIntOp::create(rewriter, loc, + rewriter.getI64IntegerAttr(2)); + Value cstNone = Torch::ConstantNoneOp::create(rewriter, loc); + Value cstFalse = Torch::ConstantBoolOp::create(rewriter, loc, + rewriter.getBoolAttr(false)); SmallVector padding, dilation, stride; SmallVector paddingInt, dilationInt, strideInt, @@ -5953,39 +5989,39 @@ class DecomposeAtenConvolutionBackwardOp // ] SmallVector outputPaddingValues; for (unsigned i = 2; i < gradRank; i++) { - Value dim = rewriter.create( - loc, rewriter.getI64IntegerAttr(i)); + Value dim = Torch::ConstantIntOp::create(rewriter, loc, + rewriter.getI64IntegerAttr(i)); Value inputVecDim = - rewriter.create(loc, input, dim); + Torch::AtenSizeIntOp::create(rewriter, loc, input, dim); Value gradOutDim = - rewriter.create(loc, gradOutput, dim); + 
Torch::AtenSizeIntOp::create(rewriter, loc, gradOutput, dim); Value weightDim = - rewriter.create(loc, weight, dim); + Torch::AtenSizeIntOp::create(rewriter, loc, weight, dim); Value inputVecDimMinusOne = - rewriter.create(loc, inputVecDim, cstOne); + Torch::AtenSubIntOp::create(rewriter, loc, inputVecDim, cstOne); Value gradOutDimMinusOne = - rewriter.create(loc, gradOutDim, cstOne); + Torch::AtenSubIntOp::create(rewriter, loc, gradOutDim, cstOne); Value weightDimMinusOne = - rewriter.create(loc, weightDim, cstOne); + Torch::AtenSubIntOp::create(rewriter, loc, weightDim, cstOne); Value twoTimesPadding = - rewriter.create(loc, padding[i - 2], cstTwo); - Value tmpA = rewriter.create( - loc, weightDimMinusOne, dilation[i - 2]); - Value tmpB = rewriter.create( - loc, gradOutDimMinusOne, stride[i - 2]); - Value outputPaddingVal = rewriter.create( - loc, inputVecDimMinusOne, twoTimesPadding); + Torch::AtenMulIntOp::create(rewriter, loc, padding[i - 2], cstTwo); + Value tmpA = Torch::AtenMulIntOp::create( + rewriter, loc, weightDimMinusOne, dilation[i - 2]); + Value tmpB = Torch::AtenMulIntOp::create( + rewriter, loc, gradOutDimMinusOne, stride[i - 2]); + Value outputPaddingVal = AtenAddIntOp::create( + rewriter, loc, inputVecDimMinusOne, twoTimesPadding); outputPaddingVal = - rewriter.create(loc, outputPaddingVal, tmpA); + AtenSubIntOp::create(rewriter, loc, outputPaddingVal, tmpA); outputPaddingVal = - rewriter.create(loc, outputPaddingVal, tmpB); + AtenSubIntOp::create(rewriter, loc, outputPaddingVal, tmpB); outputPaddingValues.push_back(outputPaddingVal); } - Value outputPaddingForGradInput = - rewriter.create( - loc, ListType::get(IntType::get(context)), outputPaddingValues); - gradInput = rewriter.create( - loc, op.getResultTypes()[0], gradOutput, weight, cstNone, + Value outputPaddingForGradInput = Torch::PrimListConstructOp::create( + rewriter, loc, ListType::get(IntType::get(context)), + outputPaddingValues); + gradInput = 
Torch::AtenConvTranspose2dInputOp::create( + rewriter, loc, op.getResultTypes()[0], gradOutput, weight, cstNone, op.getStride(), op.getPadding(), outputPaddingForGradInput, op.getGroups(), op.getDilation()); } @@ -5997,8 +6033,8 @@ class DecomposeAtenConvolutionBackwardOp if (failed(getTransposedType(cast(input.getType()), 0, 1, transposedType))) return failure(); - Value inputTransposed = rewriter.create( - loc, transposedType, input, cstZero, cstOne); + Value inputTransposed = Torch::AtenTransposeIntOp::create( + rewriter, loc, transposedType, input, cstZero, cstOne); // For the cases where the stride is non-unit, we compute the `GradWeight` // through this implementation. @@ -6006,19 +6042,19 @@ class DecomposeAtenConvolutionBackwardOp [](int64_t stride) { return stride == 1; })) { SmallVector gradOutputSize; for (unsigned i = 0; i < gradRank; i++) { - gradOutputSize.push_back(rewriter.create( - loc, gradOutput, - rewriter.create( - loc, rewriter.getI64IntegerAttr(i)))); + gradOutputSize.push_back(Torch::AtenSizeIntOp::create( + rewriter, loc, gradOutput, + Torch::ConstantIntOp::create(rewriter, loc, + rewriter.getI64IntegerAttr(i)))); } - Value gradOutputViewDimZero = rewriter.create( - loc, gradOutputSize[0], gradOutputSize[1]); - Value gradOutputViewShapeList = - rewriter.create( - loc, Torch::ListType::get(Torch::IntType::get(op.getContext())), - ValueRange{gradOutputViewDimZero, cstOne, gradOutputSize[2], - gradOutputSize[3]}); + Value gradOutputViewDimZero = Torch::AtenMulIntOp::create( + rewriter, loc, gradOutputSize[0], gradOutputSize[1]); + Value gradOutputViewShapeList = Torch::PrimListConstructOp::create( + rewriter, loc, + Torch::ListType::get(Torch::IntType::get(op.getContext())), + ValueRange{gradOutputViewDimZero, cstOne, gradOutputSize[2], + gradOutputSize[3]}); BaseTensorType gradOutputTy = cast(gradOutput.getType()); @@ -6036,8 +6072,9 @@ class DecomposeAtenConvolutionBackwardOp cast(gradOutputTy.getWithSizesAndDtype( 
llvm::ArrayRef(gradOutputViewSizesInt), gradOutputTy.getOptionalDtype())); - Value gradOutputView = rewriter.create( - loc, gradOutputTypeForView, gradOutput, gradOutputViewShapeList); + Value gradOutputView = + Torch::AtenViewOp::create(rewriter, loc, gradOutputTypeForView, + gradOutput, gradOutputViewShapeList); BaseTensorType inputTransposedTy = cast(inputTransposed.getType()); @@ -6067,9 +6104,10 @@ class DecomposeAtenConvolutionBackwardOp llvm::ArrayRef(gradWeightSizesInt), inputTransposedTy.getOptionalDtype())); - Value numGroup = rewriter.create(loc, input, cstZero); - gradWeight = rewriter.create( - loc, gradWeightTy, inputTransposed, gradOutputView, cstNone, + Value numGroup = AtenSizeIntOp::create(rewriter, loc, input, cstZero); + gradWeight = Torch::AtenConvolutionOp::create( + rewriter, loc, gradWeightTy, inputTransposed, gradOutputView, + cstNone, /*stride=*/op.getDilation(), op.getPadding(), /*dilation=*/op.getStride(), op.getTransposed(), op.getOutputPadding(), numGroup); @@ -6085,13 +6123,13 @@ class DecomposeAtenConvolutionBackwardOp llvm::ArrayRef(gradWeightSizesInt), gradWeightTy.getOptionalDtype())); - Value dim = rewriter.create( - loc, rewriter.getI64IntegerAttr(i + 2)); + Value dim = ConstantIntOp::create(rewriter, loc, + rewriter.getI64IntegerAttr(i + 2)); Value length = - rewriter.create(loc, weight, dim); - gradWeight = rewriter.create( - loc, gradWeightNarrowTy, gradWeight, dim, /*start=*/cstZero, - length); + Torch::AtenSizeIntOp::create(rewriter, loc, weight, dim); + gradWeight = Torch::AtenNarrowOp::create( + rewriter, loc, gradWeightNarrowTy, gradWeight, dim, + /*start=*/cstZero, length); } SmallVector gradWeightViewShapeInt{ @@ -6103,22 +6141,23 @@ class DecomposeAtenConvolutionBackwardOp SmallVector gradWeightViewShapeValue; for (unsigned i = 0; i < gradWeightViewShapeInt.size(); i++) { - gradWeightViewShapeValue.push_back( - rewriter.create( - loc, rewriter.getI64IntegerAttr(gradWeightViewShapeInt[i]))); + 
gradWeightViewShapeValue.push_back(Torch::ConstantIntOp::create( + rewriter, loc, + rewriter.getI64IntegerAttr(gradWeightViewShapeInt[i]))); } - Value gradWeightViewShapeList = - rewriter.create( - loc, Torch::ListType::get(Torch::IntType::get(op.getContext())), - gradWeightViewShapeValue); + Value gradWeightViewShapeList = Torch::PrimListConstructOp::create( + rewriter, loc, + Torch::ListType::get(Torch::IntType::get(op.getContext())), + gradWeightViewShapeValue); BaseTensorType gradWeightTypeForView = cast(gradWeightTy.getWithSizesAndDtype( llvm::ArrayRef(gradWeightViewShapeInt), gradWeightTy.getOptionalDtype())); - gradWeight = rewriter.create( - loc, gradWeightTypeForView, gradWeight, gradWeightViewShapeList); + gradWeight = + Torch::AtenViewOp::create(rewriter, loc, gradWeightTypeForView, + gradWeight, gradWeightViewShapeList); gradWeightTy = cast(gradWeight.getType()); SmallVector gradWeightDimsOrder = @@ -6133,15 +6172,17 @@ class DecomposeAtenConvolutionBackwardOp llvm::ArrayRef(gradWeightMoveDimShape), gradWeightTy.getOptionalDtype())); - gradWeight = rewriter.create( - loc, gradWeightTypeForMoveDim, gradWeight, /*source=*/cstZero, - /*destination=*/cstTwo); + gradWeight = + AtenMovedimIntOp::create(rewriter, loc, gradWeightTypeForMoveDim, + gradWeight, /*source=*/cstZero, + /*destination=*/cstTwo); - Value gradIntList = rewriter.create( - loc, Torch::ListType::get(Torch::IntType::get(op.getContext())), + Value gradIntList = Torch::PrimListConstructOp::create( + rewriter, loc, + Torch::ListType::get(Torch::IntType::get(op.getContext())), llvm::ArrayRef{cstZero}); - gradWeight = rewriter.create( - loc, op.getResultTypes()[1], /*self=*/gradWeight, + gradWeight = Torch::AtenSumDimIntListOp::create( + rewriter, loc, op.getResultTypes()[1], /*self=*/gradWeight, /*dim=*/gradIntList, /*keepdim=*/cstFalse, /*dtype=*/cstNone); @@ -6149,19 +6190,20 @@ class DecomposeAtenConvolutionBackwardOp if (failed(getTransposedType(cast(gradOutput.getType()), 0, 1, 
transposedType))) return failure(); - Value gradOutputTransposed = rewriter.create( - loc, transposedType, gradOutput, cstZero, cstOne); + Value gradOutputTransposed = Torch::AtenTransposeIntOp::create( + rewriter, loc, transposedType, gradOutput, cstZero, cstOne); // Convolve input with grad_output. if (failed( getTransposedType(cast(op.getResultTypes()[1]), 0, 1, transposedType))) return failure(); - gradWeight = rewriter.create( - loc, transposedType, inputTransposed, gradOutputTransposed, cstNone, - op.getStride(), op.getPadding(), op.getDilation(), - op.getTransposed(), op.getOutputPadding(), op.getGroups()); - gradWeight = rewriter.create( - loc, op.getResultTypes()[1], gradWeight, cstZero, cstOne); + gradWeight = Torch::AtenConvolutionOp::create( + rewriter, loc, transposedType, inputTransposed, + gradOutputTransposed, cstNone, op.getStride(), op.getPadding(), + op.getDilation(), op.getTransposed(), op.getOutputPadding(), + op.getGroups()); + gradWeight = Torch::AtenTransposeIntOp::create( + rewriter, loc, op.getResultTypes()[1], gradWeight, cstZero, cstOne); } } @@ -6170,16 +6212,17 @@ class DecomposeAtenConvolutionBackwardOp // Computing Grad Bias. SmallVector dimIntList{cstZero}; for (unsigned i = 2; i < gradRank; i++) - dimIntList.push_back(rewriter.create( - loc, rewriter.getI64IntegerAttr(i))); - Value gradIntList = rewriter.create( - loc, Torch::ListType::get(Torch::IntType::get(op.getContext())), + dimIntList.push_back(Torch::ConstantIntOp::create( + rewriter, loc, rewriter.getI64IntegerAttr(i))); + Value gradIntList = Torch::PrimListConstructOp::create( + rewriter, loc, + Torch::ListType::get(Torch::IntType::get(op.getContext())), dimIntList); // Sum grad_output along dim 1. 
- gradBias = rewriter.create( - loc, op.getResultTypes()[2], gradOutput, gradIntList, cstFalse, - cstNone); + gradBias = Torch::AtenSumDimIntListOp::create( + rewriter, loc, op.getResultTypes()[2], gradOutput, gradIntList, + cstFalse, cstNone); } rewriter.replaceOp(op, {gradInput, gradWeight, gradBias}); @@ -6241,13 +6284,13 @@ class DecomposeAtenNonzeroOp : public OpRewritePattern { auto intType = resultType.getDtype(); Value intTypeValue = getDtypeIntValueForType(rewriter, loc, intType); auto constantZero = - rewriter.create(loc, rewriter.getI64IntegerAttr(0)); + ConstantIntOp::create(rewriter, loc, rewriter.getI64IntegerAttr(0)); auto constantOne = - rewriter.create(loc, rewriter.getI64IntegerAttr(1)); + ConstantIntOp::create(rewriter, loc, rewriter.getI64IntegerAttr(1)); std::function makeOneElementList = [&](Value element) { auto listType = Torch::ListType::get(element.getType()); - return rewriter.create(loc, listType, - ArrayRef{element}); + return PrimListConstructOp::create(rewriter, loc, listType, + ArrayRef{element}); }; Value input = op.getSelf(); @@ -6269,73 +6312,75 @@ class DecomposeAtenNonzeroOp : public OpRewritePattern { flattendInputShape, inputType.getOptionalDtype()); // %1 = torch.aten.flatten.using_ints %arg0, %int0, %int0_0 : - auto inputDimsEnd = rewriter.create( - loc, rewriter.getI64IntegerAttr(inputRank - 1)); - Value flattenedInput = rewriter.create( - loc, flattenedInputType, input, constantZero /*inputDimsStart*/, - inputDimsEnd /*inputDimsEnd*/); + auto inputDimsEnd = ConstantIntOp::create( + rewriter, loc, rewriter.getI64IntegerAttr(inputRank - 1)); + Value flattenedInput = AtenFlattenUsingIntsOp::create( + rewriter, loc, flattenedInputType, input, + constantZero /*inputDimsStart*/, inputDimsEnd /*inputDimsEnd*/); // nonzero_mask = (t_flat != 0) auto boolMaskType = inputType.getWithSizesAndDtype( flattenedInputType.getOptionalSizes(), rewriter.getI1Type()); - Value boolMask = rewriter.create( - loc, boolMaskType, flattenedInput, 
constantZero); + Value boolMask = AtenNeScalarOp::create(rewriter, loc, boolMaskType, + flattenedInput, constantZero); // nonzero_mask = nonzero_mask.int() - Value falseCst = rewriter.create(loc, false); - Value noneCst = rewriter.create(loc); + Value falseCst = ConstantBoolOp::create(rewriter, loc, false); + Value noneCst = ConstantNoneOp::create(rewriter, loc); auto intMaskType = flattenedInputType.getWithSizesAndDtype( flattenedInputType.getOptionalSizes(), intType); - Value intMask = rewriter.create( - loc, intMaskType, boolMask, intTypeValue, falseCst, falseCst, noneCst); + Value intMask = + AtenToDtypeOp::create(rewriter, loc, intMaskType, boolMask, + intTypeValue, falseCst, falseCst, noneCst); // destination_indices = torch.cumsum(nonzero_mask, 0) - 1 - Value cumulativeSum = rewriter.create( - loc, intMaskType, intMask, constantZero, noneCst); - Value subtracted = rewriter.create( - loc, intMaskType, cumulativeSum, constantOne, /*alpha=*/constantOne); + Value cumulativeSum = AtenCumsumOp::create(rewriter, loc, intMaskType, + intMask, constantZero, noneCst); + Value subtracted = + AtenSubScalarOp::create(rewriter, loc, intMaskType, cumulativeSum, + constantOne, /*alpha=*/constantOne); // destination_indices = torch.clamp(destination_indices, min=0) - Value indices = rewriter.create(loc, intMaskType, - subtracted, constantZero); + Value indices = AtenClampMinOp::create(rewriter, loc, intMaskType, + subtracted, constantZero); // iota = torch.arange(len(t_flat)) * nonzero_mask - Value end = rewriter.create(loc, flattenedInput, - /*dim=*/constantZero); - Value rangeTensor = rewriter.create( - loc, intMaskType, /*start*/ constantZero, /*end*/ end, + Value end = AtenSizeIntOp::create(rewriter, loc, flattenedInput, + /*dim=*/constantZero); + Value rangeTensor = AtenArangeStartStepOp::create( + rewriter, loc, intMaskType, /*start*/ constantZero, /*end*/ end, /*step*/ constantOne, noneCst, noneCst, noneCst, noneCst); - Value multiplied = rewriter.create(loc, 
intMaskType, - rangeTensor, intMask); + Value multiplied = AtenMulTensorOp::create(rewriter, loc, intMaskType, + rangeTensor, intMask); // scatter_self = torch.zeros_like(t, dtype=torch.int64) // AtenFullLike doesn't support index type so we have to use int. - Value zerosTensor = rewriter.create( - loc, intMaskType, flattenedInput, intTypeValue, noneCst, noneCst, - noneCst, noneCst); + Value zerosTensor = AtenZerosLikeOp::create( + rewriter, loc, intMaskType, flattenedInput, intTypeValue, noneCst, + noneCst, noneCst, noneCst); // compacted = torch.scatter_add( // scatter_self, dim=0, index=destination_indices_clamp, src=iota) - Value scatteredTensor = rewriter.create( - loc, intMaskType, /*self*/ zerosTensor, /*dim=*/constantZero, + Value scatteredTensor = AtenScatterAddOp::create( + rewriter, loc, intMaskType, /*self*/ zerosTensor, /*dim=*/constantZero, /*index=*/indices, /*src=*/multiplied); // result_flat = compacted[:torch.sum(nonzero_mask)] auto scalarType = ValueTensorType::get(rewriter.getContext(), ArrayRef{}, intType); Value sumMask = - rewriter.create(loc, scalarType, intMask, noneCst); - Value numNonzero = rewriter.create(loc, sumMask); + AtenSumOp::create(rewriter, loc, scalarType, intMask, noneCst); + Value numNonzero = AtenIntTensorOp::create(rewriter, loc, sumMask); auto slicedResultType = Torch::ValueTensorType::get( rewriter.getContext(), SmallVector{kUnknownSize}, intType); Value slicedResult = - rewriter.create(loc, slicedResultType, - /*self=*/scatteredTensor, - /*dim=*/constantZero, - /*start=*/noneCst, - /*end=*/numNonzero, - /*step=*/constantOne); + AtenSliceTensorOp::create(rewriter, loc, slicedResultType, + /*self=*/scatteredTensor, + /*dim=*/constantZero, + /*start=*/noneCst, + /*end=*/numNonzero, + /*step=*/constantOne); // TODO fix multidim dynamic support. The following code only work for // static multidim. 
Convert flattened indices back to multi-dimensional @@ -6346,30 +6391,33 @@ class DecomposeAtenNonzeroOp : public OpRewritePattern { SmallVector shapeValues; for (int i = 0; i < inputRank; i++) { auto constantI = - rewriter.create(loc, rewriter.getI64IntegerAttr(i)); - Value shape = rewriter.create(loc, input, - /*dim=*/constantI); + ConstantIntOp::create(rewriter, loc, rewriter.getI64IntegerAttr(i)); + Value shape = AtenSizeIntOp::create(rewriter, loc, input, + /*dim=*/constantI); shapeValues.push_back(shape); } - Value shapeTensorList = rewriter.create( - loc, Torch::ListType::get(shapeValues[0].getType()), shapeValues); - Value inputShapeTensor = rewriter.create( - loc, shapeType, shapeTensorList, noneCst, noneCst, falseCst); + Value shapeTensorList = Torch::PrimListConstructOp::create( + rewriter, loc, Torch::ListType::get(shapeValues[0].getType()), + shapeValues); + Value inputShapeTensor = Torch::AtenTensorOp::create( + rewriter, loc, shapeType, shapeTensorList, noneCst, noneCst, falseCst); // strides = torch.cumprod(torch.flip(input_shape_tensor,[0]),0).flip(0) - Value flippedShape = rewriter.create( - loc, shapeType, inputShapeTensor, makeOneElementList(constantZero)); - Value cumulativeProduct = rewriter.create( - loc, shapeType, flippedShape, constantZero, noneCst); - Value flippedCumulativeProduct = rewriter.create( - loc, shapeType, cumulativeProduct, makeOneElementList(constantZero)); + Value flippedShape = + AtenFlipOp::create(rewriter, loc, shapeType, inputShapeTensor, + makeOneElementList(constantZero)); + Value cumulativeProduct = AtenCumprodOp::create( + rewriter, loc, shapeType, flippedShape, constantZero, noneCst); + Value flippedCumulativeProduct = + AtenFlipOp::create(rewriter, loc, shapeType, cumulativeProduct, + makeOneElementList(constantZero)); // strides = torch.cat([strides[1:-1], torch.tensor([1])]) auto oneTensorType = ValueTensorType::get(rewriter.getContext(), SmallVector{1}, intType); - Value oneTensor = rewriter.create( - loc, 
oneTensorType, constantOne, intTypeValue, noneCst, noneCst, - noneCst); + Value oneTensor = + AtenScalarTensorOp::create(rewriter, loc, oneTensorType, constantOne, + intTypeValue, noneCst, noneCst, noneCst); Value strides; if (inputRank > 1) { @@ -6377,20 +6425,20 @@ class DecomposeAtenNonzeroOp : public OpRewritePattern { auto slicedStrideType = Torch::ValueTensorType::get( rewriter.getContext(), SmallVector{inputRank - 1}, // sizes intType); - Value strideSliceEnd = rewriter.create( - loc, rewriter.getI64IntegerAttr(inputRank)); - Value slicedStrides = rewriter.create( - loc, slicedStrideType, /*self*/ flippedCumulativeProduct, + Value strideSliceEnd = ConstantIntOp::create( + rewriter, loc, rewriter.getI64IntegerAttr(inputRank)); + Value slicedStrides = AtenSliceTensorOp::create( + rewriter, loc, slicedStrideType, /*self*/ flippedCumulativeProduct, /*dim*/ constantZero, /*start=*/constantOne, /*end=*/strideSliceEnd, /*step=*/constantOne); // torch.cat auto tensorListElementType = Torch::ValueTensorType::get( rewriter.getContext(), SmallVector{kUnknownSize}, intType); - Value tensorList = rewriter.create( - loc, Torch::ListType::get(tensorListElementType), + Value tensorList = Torch::PrimListConstructOp::create( + rewriter, loc, Torch::ListType::get(tensorListElementType), SmallVector{slicedStrides, oneTensor}); - strides = rewriter.create(loc, shapeType, tensorList, - constantZero); + strides = Torch::AtenCatOp::create(rewriter, loc, shapeType, tensorList, + constantZero); } else { // strides[1:-1] is empty strides = oneTensor; @@ -6400,22 +6448,23 @@ class DecomposeAtenNonzeroOp : public OpRewritePattern { // input_shape_tensor auto unsqueezedResultType = ValueTensorType::get( rewriter.getContext(), SmallVector{kUnknownSize, 1}, intType); - Value unsqueezedResult = rewriter.create( - loc, unsqueezedResultType, slicedResult, constantOne); + Value unsqueezedResult = AtenUnsqueezeOp::create( + rewriter, loc, unsqueezedResultType, slicedResult, constantOne); auto 
unsqueezedStridesType = ValueTensorType::get( rewriter.getContext(), SmallVector{1, inputRank}, intType); - Value unsqueezedStrides = rewriter.create( - loc, unsqueezedStridesType, strides, constantZero); + Value unsqueezedStrides = AtenUnsqueezeOp::create( + rewriter, loc, unsqueezedStridesType, strides, constantZero); auto dividedBroadcastType = ValueTensorType::get( rewriter.getContext(), SmallVector{kUnknownSize, inputRank}, intType); - Value divided = rewriter.create( - loc, dividedBroadcastType, unsqueezedResult, unsqueezedStrides); + Value divided = + AtenFloorDivideOp::create(rewriter, loc, dividedBroadcastType, + unsqueezedResult, unsqueezedStrides); - Value modded = rewriter.create( - loc, resultType, divided, inputShapeTensor); + Value modded = AtenRemainderTensorOp::create(rewriter, loc, resultType, + divided, inputShapeTensor); rewriter.replaceOp(op, modded); return success(); @@ -6450,10 +6499,10 @@ class DecomposeAtenAddmmOp : public OpRewritePattern { } // matrix multiplication: matmul = mat1 @ mat2 - Value matmul = rewriter.create(loc, op.getType(), mat1, mat2); + Value matmul = AtenMmOp::create(rewriter, loc, op.getType(), mat1, mat2); // scaledInput = self * beta - Value scaledInput = rewriter.create(loc, input.getType(), - input, op.getBeta()); + Value scaledInput = AtenMulScalarOp::create(rewriter, loc, input.getType(), + input, op.getBeta()); // result = scaledInput + alpha * matmul rewriter.replaceOpWithNewOp(op, op.getType(), scaledInput, matmul, op.getAlpha()); @@ -6473,9 +6522,9 @@ class DecomposeAtenMeanOp : public OpRewritePattern { Value input = op.getSelf(); Value output = op.getResult(); BaseTensorType outputTensorType = cast(output.getType()); - Value sum = - rewriter.create(loc, outputTensorType, input, op.getDtype()); - Value numTensorElements = rewriter.create(loc, input); + Value sum = AtenSumOp::create(rewriter, loc, outputTensorType, input, + op.getDtype()); + Value numTensorElements = AtenNumelOp::create(rewriter, loc, input); 
rewriter.replaceOpWithNewOp(op, outputTensorType, sum, numTensorElements); return success(); @@ -6520,21 +6569,21 @@ class DecomposeAtenMeanDimOp : public OpRewritePattern { } // Compute sum along dimensions specified in `dimList`. - Value sumAlongDims = rewriter.create( - loc, outputType, input, dimList, keepDim, dtype); + Value sumAlongDims = AtenSumDimIntListOp::create( + rewriter, loc, outputType, input, dimList, keepDim, dtype); // `productDimSize` is product of sizes of dimensions to be reduced. Value productDimSize; // Case: Reduce along all dims. if (dimListElements.empty() && inputRank != 0) { - productDimSize = rewriter.create(loc, input); + productDimSize = AtenNumelOp::create(rewriter, loc, input); } else { - productDimSize = rewriter.create( - loc, rewriter.getI64IntegerAttr(1)); + productDimSize = Torch::ConstantIntOp::create( + rewriter, loc, rewriter.getI64IntegerAttr(1)); for (Value dim : dimListElements) { - Value dimSize = rewriter.create(loc, input, dim); + Value dimSize = AtenSizeIntOp::create(rewriter, loc, input, dim); productDimSize = - rewriter.create(loc, productDimSize, dimSize); + AtenMulIntOp::create(rewriter, loc, productDimSize, dimSize); } } rewriter.replaceOpWithNewOp(op, outputType, sumAlongDims, @@ -6575,8 +6624,8 @@ class DecomposeAtenStftCenterOp : public OpRewritePattern { Value hopLength = op.getHopLength(); if (isa(hopLength.getType())) { - hopLength = rewriter.create( - loc, n_fft, rewriter.create(loc, 4)); + hopLength = AtenFloordivIntOp::create( + rewriter, loc, n_fft, ConstantIntOp::create(rewriter, loc, 4)); } int64_t winLengthInt; @@ -6644,12 +6693,12 @@ class DecomposeAtenStftCenterOp : public OpRewritePattern { } if (windowTensorType.getSizes().back() == kUnknownSize) { Value actualWinLen = getTensorDimSize(rewriter, window, 0); - Value winSizeEq = rewriter.create( - loc, actualWinLen, - rewriter.create( - loc, rewriter.getI64IntegerAttr(winLengthInt))); - rewriter.create( - loc, winSizeEq, + Value winSizeEq = 
AtenEqIntOp::create( + rewriter, loc, actualWinLen, + ConstantIntOp::create(rewriter, loc, + rewriter.getI64IntegerAttr(winLengthInt))); + RuntimeAssertOp::create( + rewriter, loc, winSizeEq, rewriter.getStringAttr( "window size should be equal to win_length")); } else if (windowTensorType.getSizes().back() != winLengthInt) { @@ -6658,32 +6707,32 @@ class DecomposeAtenStftCenterOp : public OpRewritePattern { } Value cstZero = - rewriter.create(loc, rewriter.getI64IntegerAttr(0)); + ConstantIntOp::create(rewriter, loc, rewriter.getI64IntegerAttr(0)); Value cstMinusOne = - rewriter.create(loc, rewriter.getI64IntegerAttr(-1)); + ConstantIntOp::create(rewriter, loc, rewriter.getI64IntegerAttr(-1)); Value cstOne = - rewriter.create(loc, rewriter.getI64IntegerAttr(1)); - Value cstNone = rewriter.create(loc); - Value cstStrConstant = rewriter.create(loc, "constant"); + ConstantIntOp::create(rewriter, loc, rewriter.getI64IntegerAttr(1)); + Value cstNone = ConstantNoneOp::create(rewriter, loc); + Value cstStrConstant = ConstantStrOp::create(rewriter, loc, "constant"); Value cstZeroFloat = - rewriter.create(loc, rewriter.getF64FloatAttr(0.0f)); - - Value signalLen = rewriter.create(loc, self, cstMinusOne); - Value nFrames = rewriter.create( - loc, cstOne, - rewriter.create( - loc, rewriter.create(loc, signalLen, n_fft), - hopLength)); + ConstantFloatOp::create(rewriter, loc, rewriter.getF64FloatAttr(0.0f)); + + Value signalLen = AtenSizeIntOp::create(rewriter, loc, self, cstMinusOne); + Value nFrames = AtenAddIntOp::create( + rewriter, loc, cstOne, + AtenFloordivIntOp::create( + rewriter, loc, + AtenSubIntOp::create(rewriter, loc, signalLen, n_fft), hopLength)); if (hasWindow && winLengthInt < n_fftInt) { int64_t totalPad = n_fftInt - winLengthInt; int64_t leftPad = totalPad / 2; int64_t rightPad = totalPad - leftPad; - Value p1d = rewriter.create( - loc, ListType::get(IntType::get(rewriter.getContext())), - ValueRange{rewriter.create( - loc, 
rewriter.getI64IntegerAttr(leftPad)), - rewriter.create( - loc, rewriter.getI64IntegerAttr(rightPad))}); + Value p1d = PrimListConstructOp::create( + rewriter, loc, ListType::get(IntType::get(rewriter.getContext())), + ValueRange{ConstantIntOp::create(rewriter, loc, + rewriter.getI64IntegerAttr(leftPad)), + ConstantIntOp::create( + rewriter, loc, rewriter.getI64IntegerAttr(rightPad))}); Type windowOutType = selfType.getWithSizesAndDtype( SmallVector({n_fftInt}), windowDType); Value constantValue; @@ -6691,8 +6740,8 @@ class DecomposeAtenStftCenterOp : public OpRewritePattern { constantValue = cstZero; else if (isa(windowDType)) constantValue = cstZeroFloat; - window = rewriter.create(loc, windowOutType, window, p1d, - cstStrConstant, constantValue); + window = AtenPadOp::create(rewriter, loc, windowOutType, window, p1d, + cstStrConstant, constantValue); } if (hasWindow && selfType.getSizes().size() == 2) { @@ -6700,32 +6749,32 @@ class DecomposeAtenStftCenterOp : public OpRewritePattern { Type windowReshapeTensorType = windowTensorType.getWithSizesAndDtype( newWindowSizes, windowTensorType.getOptionalDtype()); Value newWindowShape = toIntListConstruct(rewriter, loc, newWindowSizes); - window = rewriter.create(loc, windowReshapeTensorType, - window, newWindowShape); + window = AtenReshapeOp::create(rewriter, loc, windowReshapeTensorType, + window, newWindowShape); } - Value nFreqs = rewriter.create( - loc, rewriter.getI64IntegerAttr(n_fftInt / 2 + 1)); + Value nFreqs = ConstantIntOp::create( + rewriter, loc, rewriter.getI64IntegerAttr(n_fftInt / 2 + 1)); SmallVector sizesValues = selfType.getSizes().size() == 2 ? 
SmallVector( {/*batch_size = */ getTensorDimSize(rewriter, self, 0), nFreqs, nFrames}) : SmallVector({nFreqs, nFrames}); - Value outputSizesList = rewriter.create( - loc, Torch::ListType::get(IntType::get(rewriter.getContext())), - sizesValues); + Value outputSizesList = Torch::PrimListConstructOp::create( + rewriter, loc, + Torch::ListType::get(IntType::get(rewriter.getContext())), sizesValues); Value resultDtype = getDtypeIntValueForType(rewriter, loc, resultType.getDtype()); - Value initFreqTensor = rewriter.create( - loc, resultType, outputSizesList, resultDtype, cstNone, cstNone, - cstNone, cstNone); + Value initFreqTensor = AtenEmptyMemoryFormatOp::create( + rewriter, loc, resultType, outputSizesList, resultDtype, cstNone, + cstNone, cstNone, cstNone); Value axisSignal = cstMinusOne; Value axisFrames = cstMinusOne; - Value loopCondTrue = rewriter.create(loc, true); + Value loopCondTrue = ConstantBoolOp::create(rewriter, loc, true); auto frameLoop = - rewriter.create(loc, TypeRange({resultType}), nFrames, - loopCondTrue, ValueRange({initFreqTensor})); + PrimLoopOp::create(rewriter, loc, TypeRange({resultType}), nFrames, + loopCondTrue, ValueRange({initFreqTensor})); { PatternRewriter::InsertionGuard guard(rewriter); Type loopIndexType = rewriter.getType(); @@ -6734,32 +6783,34 @@ class DecomposeAtenStftCenterOp : public OpRewritePattern { TypeRange({loopIndexType, resultType}), {loc, loc}); Value frame = countLoopBody->getArgument(0); Value freqTensor = countLoopBody->getArgument(1); - Value begin = rewriter.create(loc, frame, hopLength); - Value end = rewriter.create(loc, begin, n_fft); - Value narrowLen = rewriter.create( - loc, rewriter.create(loc, end, signalLen), begin); - Value missing = rewriter.create(loc, n_fft, narrowLen); + Value begin = AtenMulIntOp::create(rewriter, loc, frame, hopLength); + Value end = AtenAddIntOp::create(rewriter, loc, begin, n_fft); + Value narrowLen = AtenSubIntOp::create( + rewriter, loc, PrimMinIntOp::create(rewriter, loc, 
end, signalLen), + begin); + Value missing = AtenSubIntOp::create(rewriter, loc, n_fft, narrowLen); SmallVector slicedSizes(selfType.getSizes()); slicedSizes.back() = kUnknownSize; Type slicedTensorType = selfType.getWithSizesAndDtype( slicedSizes, selfType.getOptionalDtype()); - Value sliced = rewriter.create( - loc, slicedTensorType, self, axisSignal, begin, narrowLen); - Value padList = rewriter.create( - loc, ListType::get(IntType::get(rewriter.getContext())), + Value sliced = AtenNarrowOp::create(rewriter, loc, slicedTensorType, self, + axisSignal, begin, narrowLen); + Value padList = PrimListConstructOp::create( + rewriter, loc, ListType::get(IntType::get(rewriter.getContext())), ValueRange{cstZero, missing}); SmallVector paddedSlicedSizes(selfType.getSizes()); paddedSlicedSizes.back() = n_fftInt; Type paddedSlicedTensorType = selfType.getWithSizesAndDtype( paddedSlicedSizes, selfType.getOptionalDtype()); - Value paddedSliced = rewriter.create( - loc, paddedSlicedTensorType, - rewriter.create(loc, slicedTensorType, sliced, padList, - cstStrConstant, cstZeroFloat)); + Value paddedSliced = TensorStaticInfoCastOp::create( + rewriter, loc, paddedSlicedTensorType, + AtenPadOp::create(rewriter, loc, slicedTensorType, sliced, padList, + cstStrConstant, cstZeroFloat)); Value weighted = - hasWindow ? rewriter.create( - loc, paddedSlicedTensorType, paddedSliced, window) - : paddedSliced; + hasWindow + ? AtenMulTensorOp::create(rewriter, loc, paddedSlicedTensorType, + paddedSliced, window) + : paddedSliced; int64_t freqsDimSize = returnComplexBool ? 
resultSizes[resultSizes.size() - 2] : resultSizes[resultSizes.size() - 3]; @@ -6769,23 +6820,23 @@ class DecomposeAtenStftCenterOp : public OpRewritePattern { fftSizes, resultType.getOptionalDtype()); Value freqSliceSq; if (onesidedBool) { - freqSliceSq = rewriter.create( - loc, fftType, weighted, cstNone, axisSignal, cstNone); + freqSliceSq = AtenFftRfftOp::create(rewriter, loc, fftType, weighted, + cstNone, axisSignal, cstNone); } else { - freqSliceSq = rewriter.create( - loc, fftType, weighted, cstNone, axisSignal, cstNone); + freqSliceSq = AtenFftFftOp::create(rewriter, loc, fftType, weighted, + cstNone, axisSignal, cstNone); } SmallVector freqSliceSizes(fftSizes); freqSliceSizes.push_back(1); Type freqSliceType = resultType.getWithSizesAndDtype( freqSliceSizes, resultType.getOptionalDtype()); - Value freqSlice = rewriter.create( - loc, freqSliceType, freqSliceSq, axisFrames); - Value newFreqTensor = rewriter.create( - loc, resultType, freqTensor, freqSlice, /*dim=*/axisFrames, + Value freqSlice = AtenUnsqueezeOp::create(rewriter, loc, freqSliceType, + freqSliceSq, axisFrames); + Value newFreqTensor = AtenSliceScatterOp::create( + rewriter, loc, resultType, freqTensor, freqSlice, /*dim=*/axisFrames, /*start=*/frame, /*end=*/cstNone, /*step=*/cstOne); - rewriter.create(loc, loopCondTrue, - ValueRange({newFreqTensor})); + PrimLoopConditionOp::create(rewriter, loc, loopCondTrue, + ValueRange({newFreqTensor})); } rewriter.replaceOp(op, frameLoop.getResults()); @@ -6816,7 +6867,7 @@ class DecomposeAtenSiluOp : public OpRewritePattern { PatternRewriter &rewriter) const override { Value self = op.getSelf(); Value sigmoid = - rewriter.create(op.getLoc(), op.getType(), self); + AtenSigmoidOp::create(rewriter, op.getLoc(), op.getType(), self); rewriter.replaceOpWithNewOp(op, op.getType(), sigmoid, self); return success(); @@ -6849,14 +6900,14 @@ class DecomposeAtenDropoutOp : public OpRewritePattern { if (!inputType.hasDtype() || !isa(inputType.getDtype())) return 
rewriter.notifyMatchFailure( op, "only support floating type input for training mode"); - Value noneVal = rewriter.create(loc); + Value noneVal = ConstantNoneOp::create(rewriter, loc); Value floatOne = - rewriter.create(loc, rewriter.getF64FloatAttr(1.0)); - Value oneMinusP = rewriter.create(loc, floatOne, prob); - Value boolMask = rewriter.create( - loc, inputType, input, oneMinusP, /*generator=*/noneVal); + ConstantFloatOp::create(rewriter, loc, rewriter.getF64FloatAttr(1.0)); + Value oneMinusP = AtenSubFloatOp::create(rewriter, loc, floatOne, prob); + Value boolMask = ValsemVariantAtenBernoulliFloatOp::create( + rewriter, loc, inputType, input, oneMinusP, /*generator=*/noneVal); Value maskedInput = - rewriter.create(loc, inputType, boolMask, input); + AtenMulTensorOp::create(rewriter, loc, inputType, boolMask, input); rewriter.replaceOpWithNewOp(op, op.getType(), maskedInput, oneMinusP); return success(); @@ -6880,15 +6931,16 @@ class DeomposeAtenNativeDropoutOp op, "train must be a boolean constant or none"); } } - Value noneVal = rewriter.create(loc); + Value noneVal = ConstantNoneOp::create(rewriter, loc); if (!train) { Value i1Type = getDtypeIntValueForType(rewriter, loc, IntegerType::get(context, 1)); - Value inputSize = rewriter.create( - loc, Torch::ListType::get(Torch::IntType::get(context)), input); - Value trueValue = rewriter.create(loc, 1); - Value trueMask = rewriter.create( - loc, op->getResultTypes()[1], inputSize, trueValue, i1Type, + Value inputSize = AtenSizeOp::create( + rewriter, loc, Torch::ListType::get(Torch::IntType::get(context)), + input); + Value trueValue = ConstantIntOp::create(rewriter, loc, 1); + Value trueMask = AtenFullOp::create( + rewriter, loc, op->getResultTypes()[1], inputSize, trueValue, i1Type, /*layout=*/noneVal, /*device=*/noneVal, /*pin_memory=*/noneVal); rewriter.replaceOp(op, ArrayRef{input, trueMask}); return success(); @@ -6899,14 +6951,14 @@ class DeomposeAtenNativeDropoutOp op, "only support floating type input 
for training mode"); } Value floatOne = - rewriter.create(loc, rewriter.getF64FloatAttr(1.0)); - Value oneMinusP = rewriter.create(loc, floatOne, prob); - Value boolMask = rewriter.create( - loc, inputType, input, oneMinusP, /*generator=*/noneVal); + ConstantFloatOp::create(rewriter, loc, rewriter.getF64FloatAttr(1.0)); + Value oneMinusP = AtenSubFloatOp::create(rewriter, loc, floatOne, prob); + Value boolMask = ValsemVariantAtenBernoulliFloatOp::create( + rewriter, loc, inputType, input, oneMinusP, /*generator=*/noneVal); Value maskedInput = - rewriter.create(loc, inputType, boolMask, input); - Value output = rewriter.create( - loc, op->getResultTypes()[0], maskedInput, oneMinusP); + AtenMulTensorOp::create(rewriter, loc, inputType, boolMask, input); + Value output = AtenDivScalarOp::create( + rewriter, loc, op->getResultTypes()[0], maskedInput, oneMinusP); rewriter.replaceOp( op, ArrayRef{ output, convertTensorToDtype(rewriter, loc, boolMask, @@ -6939,12 +6991,14 @@ class DecomposeAtenVarOp : public OpRewritePattern { SmallVector dims; for (unsigned i = 0; i < inputRank; i++) - dims.push_back(rewriter.create( - loc, rewriter.getI64IntegerAttr(i))); - Value dimList = rewriter.create( - loc, Torch::ListType::get(Torch::IntType::get(op.getContext())), dims); + dims.push_back(Torch::ConstantIntOp::create( + rewriter, loc, rewriter.getI64IntegerAttr(i))); + Value dimList = PrimListConstructOp::create( + rewriter, loc, + Torch::ListType::get(Torch::IntType::get(op.getContext())), dims); - Value cstFalse = rewriter.create(op.getLoc(), false); + Value cstFalse = + Torch::ConstantBoolOp::create(rewriter, op.getLoc(), false); rewriter.replaceOpWithNewOp(op, rank0FloatTensorTy, self, dimList, op.getUnbiased(), /*keepdim=*/cstFalse); @@ -6967,8 +7021,8 @@ class DecomposeAtenStdOp : public OpRewritePattern { return rewriter.notifyMatchFailure(op, "Only aten.std support floating type"); } - Value var = rewriter.create(op->getLoc(), op.getType(), - op.getSelf(), 
op.getUnbiased()); + Value var = AtenVarOp::create(rewriter, op->getLoc(), op.getType(), + op.getSelf(), op.getUnbiased()); rewriter.replaceOpWithNewOp(op, op.getType(), var); return success(); } @@ -6988,19 +7042,19 @@ class DecomposeAtenSoftplusOp : public OpRewritePattern { BaseTensorType inputType = cast(input.getType()); Value inputTimesBeta = - rewriter.create(loc, inputType, input, op.getBeta()); + AtenMulScalarOp::create(rewriter, loc, inputType, input, op.getBeta()); // out = log1p(exp(input * beta)) / beta - Value exp = rewriter.create(loc, inputType, inputTimesBeta); - Value log1p = rewriter.create(loc, inputType, exp); + Value exp = AtenExpOp::create(rewriter, loc, inputType, inputTimesBeta); + Value log1p = AtenLog1pOp::create(rewriter, loc, inputType, exp); Value out = - rewriter.create(loc, inputType, log1p, op.getBeta()); + AtenDivScalarOp::create(rewriter, loc, inputType, log1p, op.getBeta()); // Select where x * beta > threshold auto boolResType = inputType.getWithSizesAndDtype(inputType.getSizes(), rewriter.getI1Type()); - Value condition = rewriter.create( - loc, boolResType, inputTimesBeta, op.getThreshold()); + Value condition = AtenGtScalarOp::create(rewriter, loc, boolResType, + inputTimesBeta, op.getThreshold()); rewriter.replaceOpWithNewOp(op, op.getType(), condition, input, out); @@ -7024,9 +7078,9 @@ class DecomposeAtenStdDimOp : public OpRewritePattern { op, "aten.std.dim expects input tensor of floating-point type"); } - Value varDim = rewriter.create( - op->getLoc(), op.getType(), self, op.getDim(), op.getUnbiased(), - op.getKeepdim()); + Value varDim = + AtenVarDimOp::create(rewriter, op->getLoc(), op.getType(), self, + op.getDim(), op.getUnbiased(), op.getKeepdim()); rewriter.replaceOpWithNewOp(op, op.getType(), varDim); return success(); } @@ -7062,13 +7116,13 @@ class DecomposeAtenRot90Op : public OpRewritePattern { // have different implementation for operand %. 
if (k == 1) { - Value flipDimList = rewriter.create( - loc, + Value flipDimList = Torch::PrimListConstructOp::create( + rewriter, loc, rewriter.getType(rewriter.getType()), ArrayRef{dims[1]}); Value flip = - rewriter.create(loc, self.getType(), self, flipDimList); + AtenFlipOp::create(rewriter, loc, self.getType(), self, flipDimList); rewriter.replaceOpWithNewOp( op, op.getType(), flip, dims[0], dims[1]); @@ -7076,13 +7130,13 @@ class DecomposeAtenRot90Op : public OpRewritePattern { rewriter.replaceOpWithNewOp(op, op.getType(), self, op.getDims()); } else if (k == 3) { - Value flipDimList = rewriter.create( - loc, + Value flipDimList = Torch::PrimListConstructOp::create( + rewriter, loc, rewriter.getType(rewriter.getType()), ArrayRef{dims[0]}); Value flip = - rewriter.create(loc, self.getType(), self, flipDimList); + AtenFlipOp::create(rewriter, loc, self.getType(), self, flipDimList); rewriter.replaceOpWithNewOp( op, op.getType(), flip, dims[0], dims[1]); @@ -7090,8 +7144,8 @@ class DecomposeAtenRot90Op : public OpRewritePattern { rewriter.replaceOpWithNewOp( op, op.getType(), self, /*memory_format=*/ - rewriter.create(loc, - rewriter.getI64IntegerAttr(0))); + Torch::ConstantIntOp::create(rewriter, loc, + rewriter.getI64IntegerAttr(0))); } return success(); @@ -7119,17 +7173,17 @@ class DecomposeAtenCountNonzeroOp auto inpBoolTy = inputType.getWithSizesAndDtype(inputType.getSizes(), rewriter.getI1Type()); Value cstZero = - rewriter.create(loc, rewriter.getI64IntegerAttr(0)); + ConstantIntOp::create(rewriter, loc, rewriter.getI64IntegerAttr(0)); Value nonZeroMask = - rewriter.create(loc, inpBoolTy, self, cstZero); - Value none = rewriter.create(loc); + AtenNeScalarOp::create(rewriter, loc, inpBoolTy, self, cstZero); + Value none = ConstantNoneOp::create(rewriter, loc); if (isa(dim.getType())) { rewriter.replaceOpWithNewOp(op, op.getResult().getType(), nonZeroMask, none); } else { - Value cstFalse = rewriter.create(loc, false); - Value dimIntList = rewriter.create( 
- loc, ListType::get(IntType::get(op.getContext())), + Value cstFalse = ConstantBoolOp::create(rewriter, loc, false); + Value dimIntList = PrimListConstructOp::create( + rewriter, loc, ListType::get(IntType::get(op.getContext())), SmallVector{dim}); rewriter.replaceOpWithNewOp( op, op.getResult().getType(), nonZeroMask, dimIntList, cstFalse, @@ -7164,11 +7218,11 @@ class DecomposeAtenCountNonzeroDimIntListOp auto inpBoolTy = inputType.getWithSizesAndDtype(inputType.getSizes(), rewriter.getI1Type()); Value cstZero = - rewriter.create(loc, rewriter.getI64IntegerAttr(0)); + ConstantIntOp::create(rewriter, loc, rewriter.getI64IntegerAttr(0)); Value nonZeroMask = - rewriter.create(loc, inpBoolTy, self, cstZero); - Value none = rewriter.create(loc); - Value cstFalse = rewriter.create(loc, false); + AtenNeScalarOp::create(rewriter, loc, inpBoolTy, self, cstZero); + Value none = ConstantNoneOp::create(rewriter, loc); + Value cstFalse = ConstantBoolOp::create(rewriter, loc, false); rewriter.replaceOpWithNewOp( op, op.getResult().getType(), nonZeroMask, dimList, cstFalse, none); return success(); @@ -7193,9 +7247,9 @@ class DecomposeAtenStdCorrectionOp "aten.std.correction expects input tensor of floating-point type"); } - Value varCorrection = rewriter.create( - op->getLoc(), op.getType(), self, op.getDim(), op.getCorrection(), - op.getKeepdim()); + Value varCorrection = AtenVarCorrectionOp::create( + rewriter, op->getLoc(), op.getType(), self, op.getDim(), + op.getCorrection(), op.getKeepdim()); rewriter.replaceOpWithNewOp(op, op.getType(), varCorrection); return success(); } @@ -7218,23 +7272,23 @@ class DecomposeAtenHardsigmoidOp : public OpRewritePattern { } // outputTensor = (input + 3) / 6. 
- Value constantOne = rewriter.create( - loc, rewriter.getI64IntegerAttr(1)); - Value constantThree = rewriter.create( - loc, rewriter.getI64IntegerAttr(3)); - Value constantSix = rewriter.create( - loc, rewriter.getI64IntegerAttr(6)); - Value inputPlusThree = rewriter.create( - loc, inputType, input, constantThree, /*alpha=*/constantOne); - Value outputTensor = rewriter.create( - loc, inputType, inputPlusThree, constantSix); + Value constantOne = Torch::ConstantIntOp::create( + rewriter, loc, rewriter.getI64IntegerAttr(1)); + Value constantThree = Torch::ConstantIntOp::create( + rewriter, loc, rewriter.getI64IntegerAttr(3)); + Value constantSix = Torch::ConstantIntOp::create( + rewriter, loc, rewriter.getI64IntegerAttr(6)); + Value inputPlusThree = AtenAddScalarOp::create( + rewriter, loc, inputType, input, constantThree, /*alpha=*/constantOne); + Value outputTensor = AtenDivScalarOp::create(rewriter, loc, inputType, + inputPlusThree, constantSix); // result = max(0, min(1, (input+3)/6)) - Value constantZero = rewriter.create( - loc, rewriter.getI64IntegerAttr(0)); + Value constantZero = Torch::ConstantIntOp::create( + rewriter, loc, rewriter.getI64IntegerAttr(0)); Value oneTensor = createRank0Tensor(rewriter, loc, inputType, constantOne); - Value minResult = - rewriter.create(loc, inputType, oneTensor, outputTensor); + Value minResult = AtenMinimumOp::create(rewriter, loc, inputType, oneTensor, + outputTensor); Value zeroTensor = createRank0Tensor(rewriter, loc, inputType, constantZero); rewriter.replaceOpWithNewOp(op, op.getType(), zeroTensor, @@ -7261,7 +7315,7 @@ class DecomposeAtenHardtanhOp : public OpRewritePattern { // result = min(maxVal, max(minVal, x)) Value minVal = createRank0Tensor(rewriter, loc, inputType, op.getMinVal()); Value maxResult = - rewriter.create(loc, inputType, input, minVal); + AtenMaximumOp::create(rewriter, loc, inputType, input, minVal); Value maxVal = createRank0Tensor(rewriter, loc, inputType, op.getMaxVal()); 
rewriter.replaceOpWithNewOp(op, op.getType(), maxVal, maxResult); @@ -7287,13 +7341,13 @@ class DecomposeAtenRandLikeOp : public OpRewritePattern { // Create a uniform random op with low and high set to 0.0 and 1.0, // respectively. - Value none = rewriter.create(loc); + Value none = ConstantNoneOp::create(rewriter, loc); Value zero = - rewriter.create(loc, rewriter.getF64FloatAttr(0.0)); + ConstantFloatOp::create(rewriter, loc, rewriter.getF64FloatAttr(0.0)); Value one = - rewriter.create(loc, rewriter.getF64FloatAttr(1.0)); - Value emptyTensor = rewriter.create( - loc, resultType, input, zero, op.getDtype(), op.getLayout(), + ConstantFloatOp::create(rewriter, loc, rewriter.getF64FloatAttr(1.0)); + Value emptyTensor = AtenFullLikeOp::create( + rewriter, loc, resultType, input, zero, op.getDtype(), op.getLayout(), op.getDevice(), op.getPinMemory(), op.getMemoryFormat()); rewriter.replaceOpWithNewOp(op, resultType, emptyTensor, /*from=*/zero, /*to=*/one, @@ -7330,16 +7384,17 @@ static LogicalResult decomposeBernoulliLikeOp(PatternRewriter &rewriter, // float-type tensor with the same shape as that of the `input`. Value floatTensor = convertTensorToDtype(rewriter, loc, input, rewriter.getF64Type()); - Value none = rewriter.create(loc); - Value randomVal = rewriter.create( - loc, floatTensor.getType(), floatTensor, /*dtype=*/none, /*layout=*/none, + Value none = ConstantNoneOp::create(rewriter, loc); + Value randomVal = AtenRandLikeOp::create( + rewriter, loc, floatTensor.getType(), floatTensor, /*dtype=*/none, + /*layout=*/none, /*device=*/none, /*pinMemory=*/none, /*memoryFormat=*/none); // Bernoulli(x, p) = randLike(float(x)) < p. 
auto boolResType = inputType.getWithSizesAndDtype(inputType.getSizes(), rewriter.getI1Type()); Value lessThanP = - rewriter.create(loc, boolResType, randomVal, prob); + AtenLtTensorOp::create(rewriter, loc, boolResType, randomVal, prob); // As the `output` is expected to be of the `input` type, convert the boolean // tensor `lessThanP` to a `input` type tensor. @@ -7391,7 +7446,7 @@ class DecomposeAtenBernoulliLikeOp : public OpRewritePattern { SmallVector empty; Type tensorType = inputType.getWithSizesAndDtype(llvm::ArrayRef(empty), rewriter.getF64Type()); - Value prob = rewriter.create(loc, tensorType, p); + Value prob = PrimNumToTensorScalarOp::create(rewriter, loc, tensorType, p); Value output; if (failed( decomposeBernoulliLikeOp(rewriter, op, loc, input, prob, output))) @@ -7449,24 +7504,25 @@ class DecomposeAtenExponentialOp : public OpRewritePattern { // Create a uniform random op with low and high set to 0.0 and 1.0, // respectively. - Value none = rewriter.create(loc); + Value none = ConstantNoneOp::create(rewriter, loc); Value zero = - rewriter.create(loc, rewriter.getF64FloatAttr(0.0)); + ConstantFloatOp::create(rewriter, loc, rewriter.getF64FloatAttr(0.0)); Value one = - rewriter.create(loc, rewriter.getF64FloatAttr(1.0)); - Value emptyTensor = rewriter.create( - loc, resultType, op.getSelf(), zero, /*dtype=*/none, /*layout=*/none, + ConstantFloatOp::create(rewriter, loc, rewriter.getF64FloatAttr(1.0)); + Value emptyTensor = AtenFullLikeOp::create( + rewriter, loc, resultType, op.getSelf(), zero, /*dtype=*/none, + /*layout=*/none, /*device=*/none, /*pin_memoty=*/none, /*memory_format=*/none); - Value x = rewriter.create(loc, resultType, emptyTensor, - /*from=*/zero, /*to=*/one, - /*generator=*/none); + Value x = AtenUniformOp::create(rewriter, loc, resultType, emptyTensor, + /*from=*/zero, /*to=*/one, + /*generator=*/none); - Value negX = rewriter.create(loc, resultType, x); + Value negX = AtenNegOp::create(rewriter, loc, resultType, x); Value 
oneMinusX = - rewriter.create(loc, resultType, negX, one, - /*alpha=*/one); - Value lnOneMinusX = rewriter.create(loc, resultType, oneMinusX); - Value negLambda = rewriter.create(loc, op.getLambd()); + AtenAddScalarOp::create(rewriter, loc, resultType, negX, one, + /*alpha=*/one); + Value lnOneMinusX = AtenLogOp::create(rewriter, loc, resultType, oneMinusX); + Value negLambda = AtenNegFloatOp::create(rewriter, loc, op.getLambd()); rewriter.replaceOpWithNewOp(op, resultType, lnOneMinusX, negLambda); return success(); @@ -7490,14 +7546,15 @@ class DecomposeAtenNormalFunctionalOp Value std = op.getStd(); Value mean = op.getMean(); - Value none = rewriter.create(loc); + Value none = ConstantNoneOp::create(rewriter, loc); Value one = - rewriter.create(loc, rewriter.getF64FloatAttr(1.0)); - Value randN = rewriter.create( - loc, resultType, op.getSelf(), /*dtype=*/none, /*layout=*/none, + ConstantFloatOp::create(rewriter, loc, rewriter.getF64FloatAttr(1.0)); + Value randN = AtenRandnLikeOp::create( + rewriter, loc, resultType, op.getSelf(), /*dtype=*/none, + /*layout=*/none, /*device=*/none, /*pin_memory=*/none, /*memory_format=*/none); Value stdRandN = - rewriter.create(loc, resultType, randN, std); + AtenMulScalarOp::create(rewriter, loc, resultType, randN, std); rewriter.replaceOpWithNewOp(op, resultType, stdRandN, mean, /*alpha=*/one); return success(); @@ -7516,7 +7573,7 @@ class DecomposeAtenAddCLikeOp : public OpRewritePattern { Value value = op.getValue(); Value product = - rewriter.create(loc, op.getType(), tensor1, tensor2); + T1T2Op::create(rewriter, loc, op.getType(), tensor1, tensor2); rewriter.replaceOpWithNewOp(op, op.getType(), input, product, value); return success(); @@ -7543,8 +7600,8 @@ class DecomposeAtenLayerNormOp : public OpRewritePattern { meanVarSizes[i] = input.getSizes()[i]; auto meanVarType = input.getWithSizesAndDtype(llvm::ArrayRef(meanVarSizes), input.getOptionalDtype()); - auto nativeLayerNorm = rewriter.create( - loc, op.getType(), 
meanVarType, meanVarType, op.getInput(), + auto nativeLayerNorm = AtenNativeLayerNormOp::create( + rewriter, loc, op.getType(), meanVarType, meanVarType, op.getInput(), op.getNormalizedShape(), op.getWeight(), op.getBias(), op.getEps()); rewriter.replaceOp(op, nativeLayerNorm.getResult(0)); return success(); @@ -7569,8 +7626,8 @@ class DecomposeAtenInstanceNormOp SmallVector reduceDimVals; for (int i = 2; i < inputRank; ++i) { reducedShape[i] = 1; - reduceDimVals.push_back(rewriter.create( - loc, rewriter.getI64IntegerAttr(i))); + reduceDimVals.push_back(Torch::ConstantIntOp::create( + rewriter, loc, rewriter.getI64IntegerAttr(i))); } Type dtype = inputTy.getOptionalDtype(); @@ -7579,52 +7636,52 @@ class DecomposeAtenInstanceNormOp auto sizeListType = ListType::get(IntType::get(context)); Value reduceDimList = - rewriter.create(loc, sizeListType, reduceDimVals); - Value cstTrue = rewriter.create(loc, true); - Value none = rewriter.create(loc); + PrimListConstructOp::create(rewriter, loc, sizeListType, reduceDimVals); + Value cstTrue = Torch::ConstantBoolOp::create(rewriter, loc, true); + Value none = Torch::ConstantNoneOp::create(rewriter, loc); - Value one = rewriter.create( - loc, rewriter.getI64IntegerAttr(1)); + Value one = Torch::ConstantIntOp::create(rewriter, loc, + rewriter.getI64IntegerAttr(1)); // mean(x) - Value inputMean = rewriter.create( - loc, reducedTy, op.getInput(), reduceDimList, cstTrue, none); + Value inputMean = AtenMeanDimOp::create( + rewriter, loc, reducedTy, op.getInput(), reduceDimList, cstTrue, none); // x - mean(x) - Value inputMeanExpanded = - rewriter.create(loc, inputTy, inputMean, op.getInput()); - Value inputSubMean = rewriter.create( - loc, inputTy, op.getInput(), inputMeanExpanded, one); + Value inputMeanExpanded = AtenExpandAsOp::create(rewriter, loc, inputTy, + inputMean, op.getInput()); + Value inputSubMean = AtenSubTensorOp::create( + rewriter, loc, inputTy, op.getInput(), inputMeanExpanded, one); // (x - mean(x))^2 - Value 
inputSubMeanSquare = rewriter.create( - loc, inputTy, inputSubMean, inputSubMean); + Value inputSubMeanSquare = AtenMulTensorOp::create( + rewriter, loc, inputTy, inputSubMean, inputSubMean); - Value variancesum = rewriter.create( - loc, reducedTy, inputSubMeanSquare, reduceDimList, cstTrue, + Value variancesum = AtenSumDimIntListOp::create( + rewriter, loc, reducedTy, inputSubMeanSquare, reduceDimList, cstTrue, /*dtype=*/none); int64_t elemCount = 1; for (int i = 2; i < inputRank; ++i) elemCount *= inputTy.getSizes()[i]; - Value hw = rewriter.create( - loc, rewriter.getI64IntegerAttr(elemCount)); + Value hw = Torch::ConstantIntOp::create( + rewriter, loc, rewriter.getI64IntegerAttr(elemCount)); Value inputVar = - rewriter.create(loc, reducedTy, variancesum, hw); + AtenDivScalarOp::create(rewriter, loc, reducedTy, variancesum, hw); // rsqrt(var(x) + eps) - Value inputVarPlusEps = rewriter.create( - loc, reducedTy, inputVar, op.getEps(), one); + Value inputVarPlusEps = AtenAddScalarOp::create(rewriter, loc, reducedTy, + inputVar, op.getEps(), one); Value inputRsqrtVar = - rewriter.create(loc, reducedTy, inputVarPlusEps); + AtenRsqrtOp::create(rewriter, loc, reducedTy, inputVarPlusEps); // (x - mean(x)) * rsqrt(var(x) + eps) - Value inputRsqrtVarExpanded = rewriter.create( - loc, inputTy, inputRsqrtVar, op.getInput()); - Value inputNormalized = rewriter.create( - loc, inputTy, inputSubMean, inputRsqrtVarExpanded); - Value out = rewriter.create( - loc, op.getResult().getType(), inputNormalized); + Value inputRsqrtVarExpanded = AtenExpandAsOp::create( + rewriter, loc, inputTy, inputRsqrtVar, op.getInput()); + Value inputNormalized = AtenMulTensorOp::create( + rewriter, loc, inputTy, inputSubMean, inputRsqrtVarExpanded); + Value out = TensorStaticInfoCastOp::create( + rewriter, loc, op.getResult().getType(), inputNormalized); Value weight = op.getWeight(); auto weightTy = cast(weight.getType()); @@ -7635,23 +7692,23 @@ class DecomposeAtenInstanceNormOp 
newWeightShape.push_back(1); newWeightShape.append(weightShape); - Value zero = rewriter.create( - loc, rewriter.getI64IntegerAttr(0)); + Value zero = Torch::ConstantIntOp::create(rewriter, loc, + rewriter.getI64IntegerAttr(0)); Type newWeightTy = ValueTensorType::get( op.getContext(), llvm::ArrayRef(newWeightShape), dtype); - weight = rewriter.create(loc, newWeightTy, weight, zero); + weight = AtenUnsqueezeOp::create(rewriter, loc, newWeightTy, weight, zero); while (static_cast(newWeightShape.size()) < inputRank) { - Value i = rewriter.create( - loc, rewriter.getI64IntegerAttr(newWeightShape.size())); + Value i = Torch::ConstantIntOp::create( + rewriter, loc, rewriter.getI64IntegerAttr(newWeightShape.size())); newWeightShape.push_back(1); newWeightTy = ValueTensorType::get(op.getContext(), llvm::ArrayRef(newWeightShape), dtype); - weight = rewriter.create(loc, newWeightTy, weight, i); + weight = AtenUnsqueezeOp::create(rewriter, loc, newWeightTy, weight, i); } Value weightExpanded = - rewriter.create(loc, inputTy, weight, op.getInput()); + AtenExpandAsOp::create(rewriter, loc, inputTy, weight, op.getInput()); Value bias = op.getBias(); auto biasTy = cast(bias.getType()); @@ -7664,24 +7721,24 @@ class DecomposeAtenInstanceNormOp Type newBiasTy = ValueTensorType::get(op.getContext(), llvm::ArrayRef(newBiasShape), dtype); - bias = rewriter.create(loc, newBiasTy, bias, zero); + bias = AtenUnsqueezeOp::create(rewriter, loc, newBiasTy, bias, zero); while (static_cast(newBiasShape.size()) < inputRank) { - Value i = rewriter.create( - loc, rewriter.getI64IntegerAttr(newBiasShape.size())); + Value i = Torch::ConstantIntOp::create( + rewriter, loc, rewriter.getI64IntegerAttr(newBiasShape.size())); newBiasShape.push_back(1); newBiasTy = ValueTensorType::get(op.getContext(), llvm::ArrayRef(newBiasShape), dtype); - bias = rewriter.create(loc, newBiasTy, bias, i); + bias = AtenUnsqueezeOp::create(rewriter, loc, newBiasTy, bias, i); } Value biasExpanded = - rewriter.create(loc, 
inputTy, bias, op.getInput()); + AtenExpandAsOp::create(rewriter, loc, inputTy, bias, op.getInput()); - out = rewriter.create(loc, out.getType(), out, - weightExpanded); - out = rewriter.create(loc, out.getType(), out, - biasExpanded, one); + out = AtenMulTensorOp::create(rewriter, loc, out.getType(), out, + weightExpanded); + out = AtenAddTensorOp::create(rewriter, loc, out.getType(), out, + biasExpanded, one); rewriter.replaceOp(op, out); return success(); @@ -7713,32 +7770,32 @@ class DecomposeAten_WeightNormInterfaceOp for (int64_t i = 0; i < static_cast(sizes.size()); ++i) { if (i != static_cast(dim.getDefiningOp().getValue())) - keepDims.push_back( - rewriter.create(loc, rewriter.getI64IntegerAttr(i))); + keepDims.push_back(ConstantIntOp::create( + rewriter, loc, rewriter.getI64IntegerAttr(i))); } Value ord = - rewriter.create(loc, rewriter.getI64IntegerAttr(2)); + ConstantIntOp::create(rewriter, loc, rewriter.getI64IntegerAttr(2)); Value keepdim = - rewriter.create(loc, rewriter.getBoolAttr(true)); - Value dtypeNone = rewriter.create(loc); + ConstantBoolOp::create(rewriter, loc, rewriter.getBoolAttr(true)); + Value dtypeNone = ConstantNoneOp::create(rewriter, loc); - Value dimList = rewriter.create( - loc, Torch::ListType::get(Torch::IntType::get(op->getContext())), - keepDims); + Value dimList = PrimListConstructOp::create( + rewriter, loc, + Torch::ListType::get(Torch::IntType::get(op->getContext())), keepDims); - Value norm = rewriter.create( - loc, v.getType(), v, ord, dimList, keepdim, dtypeNone); + Value norm = AtenLinalgVectorNormOp::create( + rewriter, loc, v.getType(), v, ord, dimList, keepdim, dtypeNone); - auto vShape = rewriter.create( - loc, Torch::ListType::get(rewriter.getI64Type()), v); + auto vShape = AtenSizeOp::create( + rewriter, loc, Torch::ListType::get(rewriter.getI64Type()), v); Value gDivNorm = - rewriter.create(loc, g.getType(), g, norm); + AtenDivTensorOp::create(rewriter, loc, g.getType(), g, norm); Value broadcastedGDivNorm = - 
rewriter.create(loc, v.getType(), gDivNorm, vShape); - Value vMulBroadcastedGDivNorm = rewriter.create( - loc, v.getType(), v, broadcastedGDivNorm); + AtenBroadcastToOp::create(rewriter, loc, v.getType(), gDivNorm, vShape); + Value vMulBroadcastedGDivNorm = AtenMulTensorOp::create( + rewriter, loc, v.getType(), v, broadcastedGDivNorm); rewriter.replaceOp(op, ArrayRef{vMulBroadcastedGDivNorm, norm}); return success(); @@ -7774,69 +7831,70 @@ class DecomposeAtenNativeLayerNormOp reduceDimVals.reserve(reduceDimInts.size()); std::transform(reduceDimInts.begin(), reduceDimInts.end(), std::back_inserter(reduceDimVals), [&](int64_t d) { - return rewriter.create( - loc, rewriter.getI64IntegerAttr(d)); + return Torch::ConstantIntOp::create( + rewriter, loc, rewriter.getI64IntegerAttr(d)); }); Value reduceDimList = - rewriter.create(loc, sizeListType, reduceDimVals); - Value one = rewriter.create( - loc, rewriter.getI64IntegerAttr(1)); + PrimListConstructOp::create(rewriter, loc, sizeListType, reduceDimVals); + Value one = Torch::ConstantIntOp::create(rewriter, loc, + rewriter.getI64IntegerAttr(1)); - Value cstTrue = rewriter.create(loc, true); - Value none = rewriter.create(loc); + Value cstTrue = Torch::ConstantBoolOp::create(rewriter, loc, true); + Value none = Torch::ConstantNoneOp::create(rewriter, loc); // mean(x) - Value inputMean = rewriter.create( - loc, reducedTy, op.getInput(), reduceDimList, cstTrue, none); + Value inputMean = AtenMeanDimOp::create( + rewriter, loc, reducedTy, op.getInput(), reduceDimList, cstTrue, none); Value inputMeanCasted = convertTensorToDtype(rewriter, loc, inputMean, inputTy.getDtype()); // x - mean(x) - Value inputMeanExpanded = rewriter.create( - loc, inputTy, inputMeanCasted, op.getInput()); - Value inputZeroMean = rewriter.create( - loc, inputTy, op.getInput(), inputMeanExpanded, one); + Value inputMeanExpanded = AtenExpandAsOp::create( + rewriter, loc, inputTy, inputMeanCasted, op.getInput()); + Value inputZeroMean = 
AtenSubTensorOp::create( + rewriter, loc, inputTy, op.getInput(), inputMeanExpanded, one); // var(x) = mean((x - mean(x))^2) - Value inputZeroMeanSquare = rewriter.create( - loc, inputTy, inputZeroMean, inputZeroMean); - Value inputVar = rewriter.create( - loc, reducedTy, inputZeroMeanSquare, reduceDimList, cstTrue, none); + Value inputZeroMeanSquare = AtenMulTensorOp::create( + rewriter, loc, inputTy, inputZeroMean, inputZeroMean); + Value inputVar = + AtenMeanDimOp::create(rewriter, loc, reducedTy, inputZeroMeanSquare, + reduceDimList, cstTrue, none); // rsqrt(var(x) + eps) - Value inputVarPlusEps = rewriter.create( - loc, reducedTy, inputVar, op.getEps(), one); + Value inputVarPlusEps = AtenAddScalarOp::create(rewriter, loc, reducedTy, + inputVar, op.getEps(), one); Value inputRsqrtVar = - rewriter.create(loc, reducedTy, inputVarPlusEps); + AtenRsqrtOp::create(rewriter, loc, reducedTy, inputVarPlusEps); Value inputRsqrtVarCasted = convertTensorToDtype(rewriter, loc, inputRsqrtVar, inputTy.getDtype()); // (x - mean(x)) * rsqrt(var(x) + eps) - Value inputRsqrtVarExpanded = rewriter.create( - loc, inputTy, inputRsqrtVarCasted, op.getInput()); - Value inputNormalized = rewriter.create( - loc, inputTy, inputZeroMean, inputRsqrtVarExpanded); + Value inputRsqrtVarExpanded = AtenExpandAsOp::create( + rewriter, loc, inputTy, inputRsqrtVarCasted, op.getInput()); + Value inputNormalized = AtenMulTensorOp::create( + rewriter, loc, inputTy, inputZeroMean, inputRsqrtVarExpanded); // Convert resultType if dtype is different auto resultTensorType = dyn_cast(op.getResult(0).getType()); if (inputTy.getDtype() != resultTensorType.getDtype()) { Value dtypeValue = Torch::getDtypeIntValueForType( rewriter, loc, resultTensorType.getDtype()); - Value cstFalse = rewriter.create(loc, false); - inputNormalized = rewriter.create( - loc, resultTensorType, inputNormalized, + Value cstFalse = Torch::ConstantBoolOp::create(rewriter, loc, false); + inputNormalized = 
Torch::AtenToDtypeOp::create( + rewriter, loc, resultTensorType, inputNormalized, /*dtype=*/dtypeValue, /*non_blocking=*/cstFalse, /*copy=*/cstFalse, /*memory_format=*/none); } - Value out = rewriter.create( - loc, op.getResult(0).getType(), inputNormalized); + Value out = TensorStaticInfoCastOp::create( + rewriter, loc, op.getResult(0).getType(), inputNormalized); Value weight = op.getWeight(); Value bias = op.getBias(); if (!isa(weight.getType())) { - out = rewriter.create(loc, out.getType(), out, weight); + out = AtenMulTensorOp::create(rewriter, loc, out.getType(), out, weight); } if (!isa(bias.getType())) { out = - rewriter.create(loc, out.getType(), out, bias, one); + AtenAddTensorOp::create(rewriter, loc, out.getType(), out, bias, one); } rewriter.replaceOp(op, {out, inputMean, inputRsqrtVar}); @@ -7882,10 +7940,10 @@ class DecomposeAtenRMSLayerNormOp : public OpRewritePattern { SmallVector reduceDimVals; for (int64_t dim : reduceDimInts) - reduceDimVals.push_back(rewriter.create( - loc, rewriter.getI64IntegerAttr(dim))); + reduceDimVals.push_back(Torch::ConstantIntOp::create( + rewriter, loc, rewriter.getI64IntegerAttr(dim))); Value reduceDimList = - rewriter.create(loc, sizeListType, reduceDimVals); + PrimListConstructOp::create(rewriter, loc, sizeListType, reduceDimVals); auto inputShape = inputTy.getSizes(); SmallVector reducedShape(inputShape.begin(), inputShape.end()); @@ -7894,29 +7952,29 @@ class DecomposeAtenRMSLayerNormOp : public OpRewritePattern { auto reducedTy = ValueTensorType::get(context, reducedShape, inputTy.getDtype()); // x^2 - Value inputSquared = rewriter.create(loc, inputTy, input); - Value cstTrue = rewriter.create(loc, true); - Value none = rewriter.create(loc); + Value inputSquared = AtenSquareOp::create(rewriter, loc, inputTy, input); + Value cstTrue = Torch::ConstantBoolOp::create(rewriter, loc, true); + Value none = Torch::ConstantNoneOp::create(rewriter, loc); // mean(x^2) - Value mean = rewriter.create(loc, reducedTy, 
inputSquared, - reduceDimList, cstTrue, none); + Value mean = AtenMeanDimOp::create(rewriter, loc, reducedTy, inputSquared, + reduceDimList, cstTrue, none); // mean(x^2) + eps: Add eps if provided if (!isa(op.getEps().getType())) { - Value one = rewriter.create( - loc, rewriter.getI64IntegerAttr(1)); - mean = rewriter.create(loc, reducedTy, mean, op.getEps(), - one); + Value one = Torch::ConstantIntOp::create(rewriter, loc, + rewriter.getI64IntegerAttr(1)); + mean = AtenAddScalarOp::create(rewriter, loc, reducedTy, mean, + op.getEps(), one); } // rsqrt(mean(x^2) + eps) - Value invRMS = rewriter.create(loc, reducedTy, mean); + Value invRMS = AtenRsqrtOp::create(rewriter, loc, reducedTy, mean); // rsqrt(mean(x^2) + eps) * x Value normalized = - rewriter.create(loc, inputTy, input, invRMS); + AtenMulTensorOp::create(rewriter, loc, inputTy, input, invRMS); // Optionally multiply by weight if provided Value weight = op.getWeight(); if (!isa(weight.getType())) { normalized = - rewriter.create(loc, outputTy, normalized, weight); + AtenMulTensorOp::create(rewriter, loc, outputTy, normalized, weight); } rewriter.replaceOp(op, normalized); return success(); @@ -7934,7 +7992,7 @@ class DecomposeAtenEmptyLikeOp : public OpRewritePattern { auto sizeListType = Torch::ListType::get(Torch::IntType::get(op.getContext())); Value sizeList = - rewriter.create(op.getLoc(), sizeListType, op.getSelf()); + AtenSizeOp::create(rewriter, op.getLoc(), sizeListType, op.getSelf()); FailureOr dtype = getDtypeFromOp(rewriter, op); if (failed(dtype)) { @@ -7960,10 +8018,10 @@ class DecomposeAtenArangeOp : public OpRewritePattern { // The AtenArangeOp doesn't have a start and step value. Therefore we set // them as default values 0 and 1, respectively. 
Value start, step; - start = rewriter.create( - loc, rewriter.getI64IntegerAttr(0)); - step = rewriter.create(loc, - rewriter.getI64IntegerAttr(1)); + start = Torch::ConstantIntOp::create(rewriter, loc, + rewriter.getI64IntegerAttr(0)); + step = Torch::ConstantIntOp::create(rewriter, loc, + rewriter.getI64IntegerAttr(1)); rewriter.replaceOpWithNewOp( op, op.getType(), start, op.getEnd(), step, op.getDtype(), op.getLayout(), op.getDevice(), op.getPinMemory()); @@ -7982,8 +8040,8 @@ class DecomposeAtenArangeStartOp : public OpRewritePattern { // The AtenArangeStartOp doesn't have a step value. Therefore we set it as // default value 1. Value step; - step = rewriter.create(loc, - rewriter.getI64IntegerAttr(1)); + step = Torch::ConstantIntOp::create(rewriter, loc, + rewriter.getI64IntegerAttr(1)); rewriter.replaceOpWithNewOp( op, op.getType(), op.getStart(), op.getEnd(), step, op.getDtype(), op.getLayout(), op.getDevice(), op.getPinMemory()); @@ -8010,9 +8068,9 @@ class DecomposePrimsIotaOp : public OpRewritePattern { if (!matchPattern(op.getStep(), m_TorchConstantInt(&step))) return rewriter.notifyMatchFailure( op, "unimplemented: low must be a constant integer"); - auto endVal = rewriter.create( - loc, rewriter.getI64IntegerAttr(start + length * step)); - auto none = rewriter.create(loc); + auto endVal = Torch::ConstantIntOp::create( + rewriter, loc, rewriter.getI64IntegerAttr(start + length * step)); + auto none = ConstantNoneOp::create(rewriter, loc); rewriter.replaceOpWithNewOp( op, op.getType(), op.getStart(), endVal, op.getStep(), op.getDtype(), none, op.getDevice(), none); @@ -8029,8 +8087,8 @@ class DecomposeConstantTensorAllocLikeOp : public OpRewritePattern { LogicalResult matchAndRewrite(OpTy op, PatternRewriter &rewriter) const override { Location loc = op.getLoc(); - Value constVal = rewriter.create( - loc, rewriter.getI64IntegerAttr(fillVal)); + Value constVal = Torch::ConstantIntOp::create( + rewriter, loc, rewriter.getI64IntegerAttr(fillVal)); 
rewriter.replaceOpWithNewOp( op, op.getType(), op.getSelf(), constVal, op.getDtype(), op.getLayout(), op.getDevice(), op.getPinMemory(), op.getMemoryFormat()); @@ -8058,9 +8116,9 @@ class DecomposeAtenGroupNormOp : public OpRewritePattern { op, "unimplemented: num_groups must be a constant int"); Value cstZero = - rewriter.create(loc, rewriter.getI64IntegerAttr(0)); + ConstantIntOp::create(rewriter, loc, rewriter.getI64IntegerAttr(0)); Value cstOne = - rewriter.create(loc, rewriter.getI64IntegerAttr(1)); + ConstantIntOp::create(rewriter, loc, rewriter.getI64IntegerAttr(1)); auto inputType = cast(input.getType()); if (!inputType.hasSizes()) @@ -8070,16 +8128,17 @@ class DecomposeAtenGroupNormOp : public OpRewritePattern { auto baseType = inputType.getWithSizesAndDtype( baseTypeSizes, inputType.getOptionalDtype()); - Value N = rewriter.create(loc, input, cstZero); - Value C = rewriter.create(loc, input, cstOne); - Value numElements = rewriter.create(loc, input); + Value N = AtenSizeIntOp::create(rewriter, loc, input, cstZero); + Value C = AtenSizeIntOp::create(rewriter, loc, input, cstOne); + Value numElements = AtenNumelOp::create(rewriter, loc, input); Value numElementsDivN = - rewriter.create(loc, numElements, N); - Value HxW = rewriter.create(loc, numElementsDivN, C); + AtenFloordivIntOp::create(rewriter, loc, numElements, N); + Value HxW = AtenFloordivIntOp::create(rewriter, loc, numElementsDivN, C); - AtenNativeGroupNormOp newOp = rewriter.create( - loc, ArrayRef{op.getResult().getType(), baseType, baseType}, - input, weight, bias, N, C, HxW, numGroups, eps); + AtenNativeGroupNormOp newOp = AtenNativeGroupNormOp::create( + rewriter, loc, + ArrayRef{op.getResult().getType(), baseType, baseType}, input, + weight, bias, N, C, HxW, numGroups, eps); rewriter.replaceOp(op, newOp.getResult0()); return success(); @@ -8113,24 +8172,24 @@ class DecomposeAtenNativeGroupNormOp op, "input/outputs tensor should have known sizes."); } - Value none = rewriter.create(loc); + 
Value none = ConstantNoneOp::create(rewriter, loc); Value cstZero = - rewriter.create(loc, rewriter.getI64IntegerAttr(0)); + ConstantIntOp::create(rewriter, loc, rewriter.getI64IntegerAttr(0)); Value cstOne = - rewriter.create(loc, rewriter.getI64IntegerAttr(1)); + ConstantIntOp::create(rewriter, loc, rewriter.getI64IntegerAttr(1)); Value cstNegtiveOne = - rewriter.create(loc, rewriter.getI64IntegerAttr(-1)); - Value cstTrue = rewriter.create(loc, true); - Value cstFalse = rewriter.create(loc, false); + ConstantIntOp::create(rewriter, loc, rewriter.getI64IntegerAttr(-1)); + Value cstTrue = Torch::ConstantBoolOp::create(rewriter, loc, true); + Value cstFalse = Torch::ConstantBoolOp::create(rewriter, loc, false); // GroupNorm requires the channel dimension (C) to be exactly divisible by // the number of groups. - Value channel = rewriter.create(loc, input, cstOne); + Value channel = AtenSizeIntOp::create(rewriter, loc, input, cstOne); Value remainder = - rewriter.create(loc, channel, numGroups); - Value eqOrNot = rewriter.create(loc, remainder, cstZero); - rewriter.create( - loc, eqOrNot, + AtenRemainderIntOp::create(rewriter, loc, channel, numGroups); + Value eqOrNot = AtenEqIntOp::create(rewriter, loc, remainder, cstZero); + RuntimeAssertOp::create( + rewriter, loc, eqOrNot, rewriter.getStringAttr("the number of channels must be divisible by " "the number of groups")); @@ -8156,49 +8215,50 @@ class DecomposeAtenNativeGroupNormOp : reshapeInputLastDim / numGroupsInt; reshapeInputShape.push_back(reshapeInputLastDim); - newShape.push_back(rewriter.create(loc, input, cstZero)); + newShape.push_back(AtenSizeIntOp::create(rewriter, loc, input, cstZero)); newShape.push_back(numGroups); newShape.push_back(cstNegtiveOne); Type reshapeInputType = inputType.getWithSizesAndDtype( reshapeInputShape, inputType.getOptionalDtype()); - Value reshapedInput = rewriter.create( - loc, reshapeInputType, input, - rewriter.create( - loc, Torch::ListType::get(IntType::get(context)), 
newShape)); + Value reshapedInput = AtenViewOp::create( + rewriter, loc, reshapeInputType, input, + PrimListConstructOp::create(rewriter, loc, + Torch::ListType::get(IntType::get(context)), + newShape)); // Now we proceed with the normalization steps across the 'groupSize' // Compute the mean and variance for each group - Value dimList = rewriter.create( - loc, Torch::ListType::get(Torch::IntType::get(op.getContext())), + Value dimList = PrimListConstructOp::create( + rewriter, loc, + Torch::ListType::get(Torch::IntType::get(op.getContext())), ArrayRef{cstNegtiveOne}); reshapeInputShape[2] = 1; Type reductionType = inputType.getWithSizesAndDtype( reshapeInputShape, inputType.getOptionalDtype()); auto mean = - rewriter.create(loc, reductionType, reshapedInput, - /*dims=*/dimList, /*keepdim=*/cstTrue, - /*dtype=*/none); - auto var = - rewriter.create(loc, reductionType, reshapedInput, - /*dims=*/dimList, /*unbiased=*/cstFalse, - /*keepdim=*/cstTrue); + AtenMeanDimOp::create(rewriter, loc, reductionType, reshapedInput, + /*dims=*/dimList, /*keepdim=*/cstTrue, + /*dtype=*/none); + auto var = AtenVarDimOp::create(rewriter, loc, reductionType, reshapedInput, + /*dims=*/dimList, /*unbiased=*/cstFalse, + /*keepdim=*/cstTrue); // Compute the normalized output: (input - mean) * rsqrt(var + eps) auto varPlusEps = - rewriter.create(loc, reductionType, var, eps, - /*alpha=*/cstOne); - auto invStd = rewriter.create(loc, reductionType, varPlusEps); - auto inputSubMean = rewriter.create( - loc, reshapeInputType, reshapedInput, mean, /*alpha=*/cstOne); - auto normalizedOutput = rewriter.create( - loc, reshapeInputType, inputSubMean, invStd); + AtenAddScalarOp::create(rewriter, loc, reductionType, var, eps, + /*alpha=*/cstOne); + auto invStd = AtenRsqrtOp::create(rewriter, loc, reductionType, varPlusEps); + auto inputSubMean = AtenSubTensorOp::create( + rewriter, loc, reshapeInputType, reshapedInput, mean, /*alpha=*/cstOne); + auto normalizedOutput = AtenMulTensorOp::create( + 
rewriter, loc, reshapeInputType, inputSubMean, invStd); // Reshape normalized output back to the original input shape - auto inputShape = rewriter.create( - loc, Torch::ListType::get(IntType::get(context)), input); - auto reshapedOutput = rewriter.create( - loc, inputType, normalizedOutput, /*shape=*/inputShape); + auto inputShape = AtenSizeOp::create( + rewriter, loc, Torch::ListType::get(IntType::get(context)), input); + auto reshapedOutput = AtenViewOp::create( + rewriter, loc, inputType, normalizedOutput, /*shape=*/inputShape); // Apply weight and bias if they are not None // Reshape weight and bias to C,1,1,... @@ -8208,30 +8268,30 @@ class DecomposeAtenNativeGroupNormOp viewShape.push_back(cstOne); viewShapeInt.push_back(1); } - Value viewShapeSizeList = rewriter.create( - loc, ListType::get(IntType::get(context)), viewShape); + Value viewShapeSizeList = PrimListConstructOp::create( + rewriter, loc, ListType::get(IntType::get(context)), viewShape); Type viewType = inputType.getWithSizesAndDtype( viewShapeInt, inputType.getOptionalDtype()); Value groupNormOutput = reshapedOutput; if (!isa(weight.getType())) { - auto weightReshaped = rewriter.create( - loc, viewType, weight, /*shape=*/viewShapeSizeList); - groupNormOutput = rewriter.create( - loc, inputType, groupNormOutput, weightReshaped); + auto weightReshaped = AtenViewOp::create(rewriter, loc, viewType, weight, + /*shape=*/viewShapeSizeList); + groupNormOutput = AtenMulTensorOp::create( + rewriter, loc, inputType, groupNormOutput, weightReshaped); } if (!isa(bias.getType())) { - auto biasReshaped = rewriter.create( - loc, viewType, bias, /*shape=*/viewShapeSizeList); - groupNormOutput = rewriter.create( - loc, inputType, groupNormOutput, biasReshaped, - /*alpha=*/cstOne); + auto biasReshaped = AtenViewOp::create(rewriter, loc, viewType, bias, + /*shape=*/viewShapeSizeList); + groupNormOutput = AtenAddTensorOp::create(rewriter, loc, inputType, + groupNormOutput, biasReshaped, + /*alpha=*/cstOne); } Value 
squeezedMean = - rewriter.create(loc, meanType, mean, cstNegtiveOne); - Value squeezedRsqrtVar = rewriter.create( - loc, rsqrtVarType, invStd, cstNegtiveOne); + AtenSqueezeDimOp::create(rewriter, loc, meanType, mean, cstNegtiveOne); + Value squeezedRsqrtVar = AtenSqueezeDimOp::create( + rewriter, loc, rsqrtVarType, invStd, cstNegtiveOne); rewriter.replaceOp( op, ArrayRef{groupNormOutput, squeezedMean, squeezedRsqrtVar}); @@ -8287,10 +8347,11 @@ class DecomposeAtenNativeBatchNormOp op, "expected runningMean and runningVar to be rank 1"); Value zero = - rewriter.create(loc, rewriter.getI64IntegerAttr(0)); + ConstantIntOp::create(rewriter, loc, rewriter.getI64IntegerAttr(0)); Value one = - rewriter.create(loc, rewriter.getI64IntegerAttr(1)); - Value numFeatures = rewriter.create(loc, input, /*dim=*/one); + ConstantIntOp::create(rewriter, loc, rewriter.getI64IntegerAttr(1)); + Value numFeatures = + AtenSizeIntOp::create(rewriter, loc, input, /*dim=*/one); // TODO: Add Runtime Asserts to check the shape of weight, bias, // runningMean and runningVar to be (numFeatures). @@ -8300,8 +8361,8 @@ class DecomposeAtenNativeBatchNormOp // 2. runningVar = runningVar.view(1, C, 1?, 1?, 1?) 
SmallVector runningStatsShape(inputRank, one); runningStatsShape[1] = numFeatures; - Value runningStatsSizeList = rewriter.create( - loc, ListType::get(IntType::get(context)), runningStatsShape); + Value runningStatsSizeList = PrimListConstructOp::create( + rewriter, loc, ListType::get(IntType::get(context)), runningStatsShape); SmallVector runningStatsShapeInt(inputRank, 1); runningStatsShapeInt[1] = @@ -8310,19 +8371,19 @@ class DecomposeAtenNativeBatchNormOp Type reshapeType = ValueTensorType::get( context, llvm::ArrayRef(runningStatsShapeInt), dtype); - runningMean = rewriter.create(loc, reshapeType, runningMean, - runningStatsSizeList); - runningVar = rewriter.create(loc, reshapeType, runningVar, - runningStatsSizeList); + runningMean = AtenViewOp::create(rewriter, loc, reshapeType, runningMean, + runningStatsSizeList); + runningVar = AtenViewOp::create(rewriter, loc, reshapeType, runningVar, + runningStatsSizeList); // normalizedInput = (input - runningMean) / (sqrt(runningVar + eps)). - Value inputSubMean = rewriter.create( - loc, input.getType(), input, runningMean, /*alpha=*/one); - Value varEps = rewriter.create( - loc, runningVar.getType(), runningVar, eps, /*alpha=*/one); - Value invStd = rewriter.create(loc, varEps.getType(), varEps); - Value normalizedInput = rewriter.create( - loc, inputSubMean.getType(), inputSubMean, invStd); + Value inputSubMean = AtenSubTensorOp::create( + rewriter, loc, input.getType(), input, runningMean, /*alpha=*/one); + Value varEps = AtenAddScalarOp::create(rewriter, loc, runningVar.getType(), + runningVar, eps, /*alpha=*/one); + Value invStd = AtenRsqrtOp::create(rewriter, loc, varEps.getType(), varEps); + Value normalizedInput = AtenMulTensorOp::create( + rewriter, loc, inputSubMean.getType(), inputSubMean, invStd); // The `weight` and `bias` must be reshaped to (1, C, 1?, 1?, 1?) to make it // broadcast-compatible with (N, C, D?, H?, W?). 
@@ -8335,31 +8396,32 @@ class DecomposeAtenNativeBatchNormOp std::optional weightRank = getTensorRank(weight); if (!weightRank || *weightRank != 1) return rewriter.notifyMatchFailure(op, "expected weight to be rank 1"); - weight = rewriter.create(loc, reshapeType, weight, - runningStatsSizeList); - batchNormOutput = rewriter.create( - loc, batchNormOutput.getType(), batchNormOutput, weight); + weight = AtenViewOp::create(rewriter, loc, reshapeType, weight, + runningStatsSizeList); + batchNormOutput = AtenMulTensorOp::create( + rewriter, loc, batchNormOutput.getType(), batchNormOutput, weight); } if (!isa(bias.getType())) { // Rank of `bias` must be exactly 1. std::optional biasRank = getTensorRank(bias); if (!biasRank || *biasRank != 1) return rewriter.notifyMatchFailure(op, "expected bias to be rank 1"); - bias = rewriter.create(loc, reshapeType, bias, - runningStatsSizeList); - batchNormOutput = rewriter.create( - loc, batchNormOutput.getType(), batchNormOutput, bias, /*alpha=*/one); + bias = AtenViewOp::create(rewriter, loc, reshapeType, bias, + runningStatsSizeList); + batchNormOutput = + AtenAddTensorOp::create(rewriter, loc, batchNormOutput.getType(), + batchNormOutput, bias, /*alpha=*/one); } // The `mean` and `invstd` outputs are empty tensors in inference mode. 
- Value zeroList = rewriter.create( - loc, Torch::ListType::get(zero.getType()), zero); - Value none = rewriter.create(loc); - Value emptyMeanTensor = rewriter.create( - loc, op.getType(1), zeroList, /*dtype=*/none, /*layout=*/none, + Value zeroList = PrimListConstructOp::create( + rewriter, loc, Torch::ListType::get(zero.getType()), zero); + Value none = ConstantNoneOp::create(rewriter, loc); + Value emptyMeanTensor = AtenEmptyMemoryFormatOp::create( + rewriter, loc, op.getType(1), zeroList, /*dtype=*/none, /*layout=*/none, /*device=*/none, /*pinMemory=*/none, /*memoryFormat=*/none); - Value emptyInvStdTensor = rewriter.create( - loc, op.getType(2), zeroList, /*dtype=*/none, /*layout=*/none, + Value emptyInvStdTensor = AtenEmptyMemoryFormatOp::create( + rewriter, loc, op.getType(2), zeroList, /*dtype=*/none, /*layout=*/none, /*device=*/none, /*pinMemory=*/none, /*memoryFormat=*/none); rewriter.replaceOp(op, @@ -8454,8 +8516,8 @@ class DecomposeAtenFullOp : public OpRewritePattern { auto dtype = getTypeForTorchType(op.getContext(), op.getFillValue().getType()); Type tensorType = outTy.getWithSizesAndDtype(llvm::ArrayRef(empty), dtype); - Value fillVal = rewriter.create(loc, tensorType, - op.getFillValue()); + Value fillVal = PrimNumToTensorScalarOp::create(rewriter, loc, tensorType, + op.getFillValue()); fillVal = convertTensorToDtype(rewriter, loc, fillVal, outTy.getDtype()); rewriter.replaceOpWithNewOp(op, op.getType(), fillVal, op.getSize()); @@ -8490,7 +8552,7 @@ class DecomposeAtenLinearOp : public OpRewritePattern { Type transposeType = weightType.getWithSizesAndDtype( llvm::ArrayRef(transposeShape), weightType.getOptionalDtype()); Value transposeWeight = - rewriter.create(loc, transposeType, weight); + AtenTOp::create(rewriter, loc, transposeType, weight); return transposeWeight; }; @@ -8520,10 +8582,10 @@ class DecomposeAtenLinearOp : public OpRewritePattern { return rewriter.notifyMatchFailure(op, "expected weight to be a rank 2"); - Value matmul = 
rewriter.create(loc, op.getType(), input, - transposeWeight()); + Value matmul = AtenMatmulOp::create(rewriter, loc, op.getType(), input, + transposeWeight()); Value alpha = - rewriter.create(loc, rewriter.getI64IntegerAttr(1)); + ConstantIntOp::create(rewriter, loc, rewriter.getI64IntegerAttr(1)); rewriter.replaceOpWithNewOp(op, op.getType(), matmul, op.getBias(), alpha); return success(); @@ -8555,8 +8617,8 @@ class DecomposeAtenMishOp : public OpRewritePattern { Value threshold = getConstantWithGivenDtypeAndValue(rewriter, loc, 20.0, dType); Value softplusOp = - rewriter.create(loc, type, input, beta, threshold); - Value tanhOp = rewriter.create(loc, type, softplusOp); + AtenSoftplusOp::create(rewriter, loc, type, input, beta, threshold); + Value tanhOp = AtenTanhOp::create(rewriter, loc, type, softplusOp); rewriter.replaceOpWithNewOp(op, type, input, tanhOp); return success(); } @@ -8579,8 +8641,8 @@ class DecomposeAtenFullLikeOp : public OpRewritePattern { auto dtype = getTypeForTorchType(op.getContext(), op.getFillValue().getType()); Type tensorType = outTy.getWithSizesAndDtype(llvm::ArrayRef(empty), dtype); - Value fillVal = rewriter.create( - op.getLoc(), tensorType, op.getFillValue()); + Value fillVal = PrimNumToTensorScalarOp::create( + rewriter, op.getLoc(), tensorType, op.getFillValue()); fillVal = convertTensorToDtype(rewriter, op.getLoc(), fillVal, outTy.getDtype()); rewriter.replaceOpWithNewOp(op, op.getType(), fillVal, @@ -8625,7 +8687,7 @@ class DecomposeAtenExpandAsOp : public OpRewritePattern { auto sizeListType = Torch::ListType::get(Torch::IntType::get(op.getContext())); Value sizeList = - rewriter.create(op.getLoc(), sizeListType, op.getOther()); + AtenSizeOp::create(rewriter, op.getLoc(), sizeListType, op.getOther()); rewriter.replaceOpWithNewOp(op, op.getType(), op.getSelf(), sizeList); return success(); @@ -8648,8 +8710,8 @@ class DecomposeAten_ToCopyOp : public OpRewritePattern { Type resultDtype = resultType.getDtype(); Value zero = 
getConstantWithGivenDtypeAndValue(rewriter, op.getLoc(), 0.0, resultDtype); - Value emptyTensor = rewriter.create( - op.getLoc(), op.getType(), op.getSelf(), zero, op.getDtype(), + Value emptyTensor = AtenFullLikeOp::create( + rewriter, op.getLoc(), op.getType(), op.getSelf(), zero, op.getDtype(), op.getLayout(), op.getDevice(), op.getPinMemory(), op.getMemoryFormat()); rewriter.replaceOpWithNewOp(op, op.getType(), emptyTensor, @@ -8692,7 +8754,7 @@ class DecomposeAtenNewEmptyOp : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(AtenNewEmptyOp op, PatternRewriter &rewriter) const override { - Value noneVal = rewriter.create(op.getLoc()); + Value noneVal = ConstantNoneOp::create(rewriter, op.getLoc()); FailureOr dtype = getDtypeFromOp(rewriter, op); if (failed(dtype)) { return rewriter.notifyMatchFailure( @@ -8721,8 +8783,8 @@ class DecomposeAtenPadOp : public OpRewritePattern { if (isa(value.getType())) return rewriter.notifyMatchFailure(op, "optional type not supported"); if (isa(value.getType())) - value = rewriter.create( - op.getLoc(), rewriter.getF64FloatAttr(0)); + value = Torch::ConstantFloatOp::create(rewriter, op.getLoc(), + rewriter.getF64FloatAttr(0)); rewriter.replaceOpWithNewOp( op, op.getType(), op.getSelf(), op.getPad(), value); @@ -8758,8 +8820,8 @@ class DecomposeAtenPadOp : public OpRewritePattern { if (usefulPadIndexEnd < padValues.size()) { ArrayRef usefulPadValues(padValues.begin(), padValues.begin() + usefulPadIndexEnd); - usefulPads = rewriter.create( - op.getLoc(), + usefulPads = PrimListConstructOp::create( + rewriter, op.getLoc(), rewriter.getType(rewriter.getType()), usefulPadValues); } @@ -8878,11 +8940,11 @@ class DecomposeAtenToPrimDeviceOp // Device information isn't relevant to torch-mlir, so we can drop that info // here. 
auto loc = op.getLoc(); - Value constNone = rewriter.create(loc); + Value constNone = ConstantNoneOp::create(rewriter, loc); Value dtype = op.getDtype(); if (isa(dtype.getType())) { - dtype = rewriter.create(loc, op.getSelf()); + dtype = Torch::PrimDtypeOp::create(rewriter, loc, op.getSelf()); } rewriter.replaceOpWithNewOp(op, op.getType(), op.getSelf(), dtype, op.getNonBlocking(), @@ -8934,20 +8996,20 @@ class DecomposeAtenAdaptivePool1dOp : public OpRewritePattern { return rewriter.notifyMatchFailure(op, "expected input to have a rank"); } unsigned rank = *maybeRank; - Value sizeDim = rewriter.create( - loc, rewriter.getI64IntegerAttr(rank - 1)); + Value sizeDim = Torch::ConstantIntOp::create( + rewriter, loc, rewriter.getI64IntegerAttr(rank - 1)); Value inputSize = rewriter.createOrFold(loc, input, sizeDim); SmallVector outputShapeSizesTorchInt; getListConstructElements(op.getOutputSize(), outputShapeSizesTorchInt); Value outputSize = outputShapeSizesTorchInt[0]; - Value constantOne = rewriter.create( - loc, rewriter.getI64IntegerAttr(1)); - Value constantZero = rewriter.create( - loc, rewriter.getI64IntegerAttr(0)); - Value constantFalse = rewriter.create(loc, false); - Value constantTrue = rewriter.create(loc, true); + Value constantOne = Torch::ConstantIntOp::create( + rewriter, loc, rewriter.getI64IntegerAttr(1)); + Value constantZero = Torch::ConstantIntOp::create( + rewriter, loc, rewriter.getI64IntegerAttr(0)); + Value constantFalse = Torch::ConstantBoolOp::create(rewriter, loc, false); + Value constantTrue = Torch::ConstantBoolOp::create(rewriter, loc, true); int64_t outputSizeInt; if (!matchPattern(outputSize, m_TorchConstantInt(&outputSizeInt))) { @@ -8960,22 +9022,23 @@ class DecomposeAtenAdaptivePool1dOp : public OpRewritePattern { kernelSize.push_back(inputSize); } else { if (!isAssumingStrictSymbolicShapes(rewriter)) { - Value cond = rewriter.create(loc, inputSize, outputSize); - rewriter.create( - loc, cond, + Value cond = 
AtenEqIntOp::create(rewriter, loc, inputSize, outputSize); + RuntimeAssertOp::create( + rewriter, loc, cond, "unimplemented: only support cases where input and output size are " "equal for non-unit output size"); } kernelSize.push_back(constantOne); } - Value kernelSizeList = rewriter.create( - loc, Torch::ListType::get(Torch::IntType::get(context)), kernelSize); - Value strideList = rewriter.create( - loc, Torch::ListType::get(Torch::IntType::get(context)), + Value kernelSizeList = PrimListConstructOp::create( + rewriter, loc, Torch::ListType::get(Torch::IntType::get(context)), + kernelSize); + Value strideList = PrimListConstructOp::create( + rewriter, loc, Torch::ListType::get(Torch::IntType::get(context)), ValueRange{constantOne}); - Value paddingSizeList = rewriter.create( - loc, Torch::ListType::get(Torch::IntType::get(context)), + Value paddingSizeList = PrimListConstructOp::create( + rewriter, loc, Torch::ListType::get(Torch::IntType::get(context)), ValueRange{constantZero}); if constexpr (std::is_same_v) { @@ -8984,18 +9047,18 @@ class DecomposeAtenAdaptivePool1dOp : public OpRewritePattern { /*ceil_mode=*/constantFalse, /*count_include_pad=*/constantTrue); return success(); } else if constexpr (std::is_same_v) { - Value dilationList = rewriter.create( - loc, Torch::ListType::get(Torch::IntType::get(context)), + Value dilationList = PrimListConstructOp::create( + rewriter, loc, Torch::ListType::get(Torch::IntType::get(context)), ValueRange{constantOne}); if (op.getResult(1).use_empty()) { - auto maxPool = rewriter.create( - loc, op.getType(0), input, kernelSizeList, strideList, + auto maxPool = AtenMaxPool1dOp::create( + rewriter, loc, op.getType(0), input, kernelSizeList, strideList, paddingSizeList, dilationList, /*ceil_mode=*/constantFalse); rewriter.replaceOp(op, {maxPool.getResult(), Value()}); } else { - auto maxPool = rewriter.create( - loc, op.getType(0), op.getType(1), input, kernelSizeList, + auto maxPool = AtenMaxPool1dWithIndicesOp::create( + 
rewriter, loc, op.getType(0), op.getType(1), input, kernelSizeList, strideList, paddingSizeList, dilationList, /*ceil_mode=*/constantFalse); rewriter.replaceOp(op, maxPool.getResults()); @@ -9037,12 +9100,12 @@ class DecomposeAtenAdaptivePool2dOp : public OpRewritePattern { } unsigned rank = *maybeRank; SmallVector inputHW; - Value dimH = rewriter.create( - loc, rewriter.getI64IntegerAttr(rank - 2)); + Value dimH = Torch::ConstantIntOp::create( + rewriter, loc, rewriter.getI64IntegerAttr(rank - 2)); inputHW.push_back( /*inH=*/rewriter.createOrFold(loc, input, dimH)); - Value dimW = rewriter.create( - loc, rewriter.getI64IntegerAttr(rank - 1)); + Value dimW = Torch::ConstantIntOp::create( + rewriter, loc, rewriter.getI64IntegerAttr(rank - 1)); inputHW.push_back( /*inW=*/rewriter.createOrFold(loc, input, dimW)); @@ -9054,78 +9117,82 @@ class DecomposeAtenAdaptivePool2dOp : public OpRewritePattern { // the stride/kernel size is not fixed. // The following logic of stride/kernel size derivation is consistent // with torch/_decomp/decomposations.py:adaptive_avg_pool2d. 
- Value constantZero = rewriter.create( - loc, rewriter.getI64IntegerAttr(0)); - Value constantOne = rewriter.create( - loc, rewriter.getI64IntegerAttr(1)); - Value constantFalse = rewriter.create(loc, false); - Value constantTrue = rewriter.create(loc, true); - Value constantNone = rewriter.create(loc); + Value constantZero = Torch::ConstantIntOp::create( + rewriter, loc, rewriter.getI64IntegerAttr(0)); + Value constantOne = Torch::ConstantIntOp::create( + rewriter, loc, rewriter.getI64IntegerAttr(1)); + Value constantFalse = Torch::ConstantBoolOp::create(rewriter, loc, false); + Value constantTrue = Torch::ConstantBoolOp::create(rewriter, loc, true); + Value constantNone = Torch::ConstantNoneOp::create(rewriter, loc); SmallVector strideSize; SmallVector kernelSize; for (unsigned i = 0; i < inputHW.size(); i++) { - Value remainder = rewriter.create( - loc, inputHW[i], outputShapeSizesTorchInt[i]); + Value remainder = AtenRemainderIntOp::create(rewriter, loc, inputHW[i], + outputShapeSizesTorchInt[i]); // Filter cases with fixed stride size. - Value cond1 = rewriter.create( - loc, outputShapeSizesTorchInt[i], - rewriter.create( - loc, remainder, - rewriter.create( - loc, outputShapeSizesTorchInt[i], constantOne))); - rewriter.create( - loc, cond1, + Value cond1 = Torch::AtenGtIntOp::create( + rewriter, loc, outputShapeSizesTorchInt[i], + Torch::AtenMulIntOp::create( + rewriter, loc, remainder, + Torch::AtenSubIntOp::create( + rewriter, loc, outputShapeSizesTorchInt[i], constantOne))); + RuntimeAssertOp::create( + rewriter, loc, cond1, "unimplemented: only support cases with fixed stride size."); // Filter cases with fixed kernel size. // cond2: whether input_size % output_size == 0. Value cond2 = - rewriter.create(loc, remainder, constantZero); + Torch::AtenEqIntOp::create(rewriter, loc, remainder, constantZero); // cond3: whether output_size % (input_size % output_size) == 0. // To avoid potential crash (eg. 
tosa) happens,choose to mod 1 (add // offset) when remainder equals 0, which has no side effect on // effectiveness. - Value offset = rewriter.create( - loc, rewriter.create( - loc, rewriter.create(loc, remainder))); + Value offset = Torch::AtenIntBoolOp::create( + rewriter, loc, + Torch::Aten__Not__Op::create( + rewriter, loc, + Torch::AtenBoolIntOp::create(rewriter, loc, remainder))); Value remainder_not_zero = - rewriter.create(loc, remainder, offset); - Value cond3 = rewriter.create( - loc, - rewriter.create( - loc, outputShapeSizesTorchInt[i], remainder_not_zero), + Torch::AtenAddIntOp::create(rewriter, loc, remainder, offset); + Value cond3 = Torch::AtenEqIntOp::create( + rewriter, loc, + Torch::AtenRemainderIntOp::create( + rewriter, loc, outputShapeSizesTorchInt[i], remainder_not_zero), constantZero); - Value cond = rewriter.create(loc, cond2, cond3); + Value cond = Torch::Aten__Or__BoolOp::create(rewriter, loc, cond2, cond3); - rewriter.create( - loc, cond, + RuntimeAssertOp::create( + rewriter, loc, cond, "unimplemented: only support cases with fixed kernel size."); - Value stride = rewriter.create( - loc, inputHW[i], outputShapeSizesTorchInt[i]); + Value stride = Torch::AtenFloordivIntOp::create( + rewriter, loc, inputHW[i], outputShapeSizesTorchInt[i]); strideSize.emplace_back(stride); - Value kernel = rewriter.create( - loc, inputHW[i], outputShapeSizesTorchInt[i]); + Value kernel = Torch::AtenFloordivIntOp::create( + rewriter, loc, inputHW[i], outputShapeSizesTorchInt[i]); // When remainder equals 0, it is no need for kernel to add 1 // and just keep the same as stride, otherwise it is necessary // to add 1 (torch/_decomp/decomposations.py:adaptive_avg_pool2d). 
- Value boolMod = rewriter.create(loc, remainder); - Value intMod = rewriter.create(loc, boolMod); + Value boolMod = Torch::AtenBoolIntOp::create(rewriter, loc, remainder); + Value intMod = Torch::AtenIntBoolOp::create(rewriter, loc, boolMod); - kernel = rewriter.create(loc, kernel, intMod); + kernel = Torch::AtenAddIntOp::create(rewriter, loc, kernel, intMod); kernelSize.emplace_back(kernel); } - Value kernelSizeList = rewriter.create( - loc, Torch::ListType::get(Torch::IntType::get(context)), kernelSize); - Value strideList = rewriter.create( - loc, Torch::ListType::get(Torch::IntType::get(context)), strideSize); - Value paddingSizeList = rewriter.create( - loc, Torch::ListType::get(Torch::IntType::get(context)), + Value kernelSizeList = PrimListConstructOp::create( + rewriter, loc, Torch::ListType::get(Torch::IntType::get(context)), + kernelSize); + Value strideList = PrimListConstructOp::create( + rewriter, loc, Torch::ListType::get(Torch::IntType::get(context)), + strideSize); + Value paddingSizeList = PrimListConstructOp::create( + rewriter, loc, Torch::ListType::get(Torch::IntType::get(context)), ValueRange{constantZero, constantZero}); if constexpr (std::is_same_v) { @@ -9135,17 +9202,17 @@ class DecomposeAtenAdaptivePool2dOp : public OpRewritePattern { /*divisorOverride=*/constantNone); return success(); } else if constexpr (std::is_same_v) { - Value dilationList = rewriter.create( - loc, Torch::ListType::get(Torch::IntType::get(context)), + Value dilationList = PrimListConstructOp::create( + rewriter, loc, Torch::ListType::get(Torch::IntType::get(context)), ValueRange{constantOne, constantOne}); if (op.getResult(1).use_empty()) { - auto maxPool = rewriter.create( - loc, op.getType(0), input, kernelSizeList, strideList, + auto maxPool = AtenMaxPool2dOp::create( + rewriter, loc, op.getType(0), input, kernelSizeList, strideList, paddingSizeList, dilationList, /*ceil_mode=*/constantFalse); rewriter.replaceOp(op, {maxPool.getResult(), Value()}); } else { - 
auto maxPool = rewriter.create( - loc, op.getType(0), op.getType(1), input, kernelSizeList, + auto maxPool = AtenMaxPool2dWithIndicesOp::create( + rewriter, loc, op.getType(0), op.getType(1), input, kernelSizeList, strideList, paddingSizeList, dilationList, /*ceil_mode=*/constantFalse); rewriter.replaceOp(op, maxPool.getResults()); @@ -9164,7 +9231,7 @@ class DecomposeAtenClampMinOp : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(AtenClampMinOp op, PatternRewriter &rewriter) const override { - Value constantNone = rewriter.create(op.getLoc()); + Value constantNone = Torch::ConstantNoneOp::create(rewriter, op.getLoc()); rewriter.replaceOpWithNewOp(op, op.getType(), op.getSelf(), op.getMin(), /*max=*/constantNone); return success(); @@ -9179,7 +9246,7 @@ class DecomposeAtenClampMinTensorOp using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(AtenClampMinTensorOp op, PatternRewriter &rewriter) const override { - Value constantNone = rewriter.create(op.getLoc()); + Value constantNone = Torch::ConstantNoneOp::create(rewriter, op.getLoc()); rewriter.replaceOpWithNewOp( op, op.getType(), op.getSelf(), op.getMin(), /*max=*/constantNone); return success(); @@ -9193,7 +9260,7 @@ class DecomposeAtenClampMaxOp : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(AtenClampMaxOp op, PatternRewriter &rewriter) const override { - Value constantNone = rewriter.create(op.getLoc()); + Value constantNone = Torch::ConstantNoneOp::create(rewriter, op.getLoc()); rewriter.replaceOpWithNewOp(op, op.getType(), op.getSelf(), /*min=*/constantNone, op.getMax()); return success(); @@ -9206,8 +9273,8 @@ class DecomposeAtenRad2degOp : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(AtenRad2degOp op, PatternRewriter &rewriter) const override { - Value constant180OverPi = rewriter.create( - op.getLoc(), rewriter.getF64FloatAttr(180 / 
3.14159)); + Value constant180OverPi = Torch::ConstantFloatOp::create( + rewriter, op.getLoc(), rewriter.getF64FloatAttr(180 / 3.14159)); rewriter.replaceOpWithNewOp(op, op.getType(), op.getSelf(), constant180OverPi); return success(); @@ -9234,42 +9301,44 @@ class DecomposeAtenCosineSimilarityOp Type dtype = cast(x1.getType()).getOptionalDtype(); Type broadcastType = ValueTensorType::get( op.getContext(), llvm::ArrayRef(indexBroadcastShapeInt), dtype); - Value indexBroadcastShapeTorchList = rewriter.create( - loc, Torch::ListType::get(Torch::IntType::get(op.getContext())), + Value indexBroadcastShapeTorchList = PrimListConstructOp::create( + rewriter, loc, + Torch::ListType::get(Torch::IntType::get(op.getContext())), indexBroadcastShapeValue); - x1 = rewriter.create(loc, broadcastType, x1, - indexBroadcastShapeTorchList); - x2 = rewriter.create(loc, broadcastType, x2, - indexBroadcastShapeTorchList); + x1 = AtenBroadcastToOp::create(rewriter, loc, broadcastType, x1, + indexBroadcastShapeTorchList); + x2 = AtenBroadcastToOp::create(rewriter, loc, broadcastType, x2, + indexBroadcastShapeTorchList); // Compute the mul of A and B Value dotProduct = - rewriter.create(loc, broadcastType, x1, x2); - Value cstFalse = rewriter.create(loc, false); - Value cstNone = rewriter.create(loc); - Value dimList = rewriter.create( - loc, Torch::ListType::get(Torch::IntType::get(op->getContext())), + AtenMulTensorOp::create(rewriter, loc, broadcastType, x1, x2); + Value cstFalse = Torch::ConstantBoolOp::create(rewriter, loc, false); + Value cstNone = Torch::ConstantNoneOp::create(rewriter, loc); + Value dimList = PrimListConstructOp::create( + rewriter, loc, + Torch::ListType::get(Torch::IntType::get(op->getContext())), ValueRange{dim}); - Value sumDotProduct = rewriter.create( - loc, op.getType(), /*self=*/dotProduct, /*dim=*/dimList, + Value sumDotProduct = Torch::AtenSumDimIntListOp::create( + rewriter, loc, op.getType(), /*self=*/dotProduct, /*dim=*/dimList, /*keepdim=*/cstFalse, 
/*dtype=*/cstNone); // Compute the norm of A and B - Value ord = rewriter.create( - loc, rewriter.getF64FloatAttr(2.0)); - Value normA = rewriter.create( - loc, op.getType(), x1, ord, dimList, /*keepdim=*/cstFalse, + Value ord = Torch::ConstantFloatOp::create(rewriter, loc, + rewriter.getF64FloatAttr(2.0)); + Value normA = AtenLinalgVectorNormOp::create( + rewriter, loc, op.getType(), x1, ord, dimList, /*keepdim=*/cstFalse, /*dtype=*/cstNone); - Value normB = rewriter.create( - loc, op.getType(), x2, ord, dimList, /*keepdim=*/cstFalse, + Value normB = AtenLinalgVectorNormOp::create( + rewriter, loc, op.getType(), x2, ord, dimList, /*keepdim=*/cstFalse, /*dtype=*/cstNone); // Compute the product of the norms Value normProduct = - rewriter.create(loc, op.getType(), normA, normB); - Value normProductClamp = rewriter.create( - loc, op.getType(), normProduct, op.getEps(), /*max=*/cstNone); + AtenMulTensorOp::create(rewriter, loc, op.getType(), normA, normB); + Value normProductClamp = AtenClampOp::create( + rewriter, loc, op.getType(), normProduct, op.getEps(), /*max=*/cstNone); // Compute the final cosine similarity by division rewriter.replaceOpWithNewOp( op, op.getType(), sumDotProduct, normProductClamp); @@ -9293,9 +9362,9 @@ class DecomposeAtenTruncOp : public OpRewritePattern { } if (isa(resultTy.getDtype())) { - Value sign = rewriter.create(loc, resultTy, self); - Value abs = rewriter.create(loc, resultTy, self); - Value floor = rewriter.create(loc, resultTy, abs); + Value sign = AtenSgnOp::create(rewriter, loc, resultTy, self); + Value abs = AtenAbsOp::create(rewriter, loc, resultTy, self); + Value floor = AtenFloorOp::create(rewriter, loc, resultTy, abs); rewriter.replaceOpWithNewOp(op, resultTy, sign, floor); return success(); } @@ -9338,19 +9407,19 @@ class DecomposeAtenSignbitOp : public OpRewritePattern { mlir::IntegerType intType = rewriter.getIntegerType( operandTy.getDtype().getIntOrFloatBitWidth(), /*isSigned*/ true); Value dtype = 
getDtypeIntValueForType(rewriter, loc, intType); - Value view = rewriter.create( - loc, + Value view = AtenViewDtypeOp::create( + rewriter, loc, operandTy.getWithSizesAndDtype(operandTy.getOptionalSizes(), intType), self, dtype); Value zero = - rewriter.create(loc, rewriter.getI64IntegerAttr(0)); - Value shift = rewriter.create(loc, resultTy, view, zero); + ConstantIntOp::create(rewriter, loc, rewriter.getI64IntegerAttr(0)); + Value shift = AtenLtScalarOp::create(rewriter, loc, resultTy, view, zero); rewriter.replaceOp(op, shift); return success(); } else if (isa(operandTy.getDtype())) { Value zero = - rewriter.create(loc, rewriter.getI64IntegerAttr(0)); - Value shift = rewriter.create(loc, resultTy, self, zero); + ConstantIntOp::create(rewriter, loc, rewriter.getI64IntegerAttr(0)); + Value shift = AtenLtScalarOp::create(rewriter, loc, resultTy, self, zero); rewriter.replaceOp(op, shift); } return failure(); @@ -9369,8 +9438,8 @@ class DecomposeAtenFracOp : public OpRewritePattern { auto resultTy = op.getType(); Value one = - rewriter.create(loc, rewriter.getI64IntegerAttr(1)); - Value trunc = rewriter.create(loc, resultTy, self); + ConstantIntOp::create(rewriter, loc, rewriter.getI64IntegerAttr(1)); + Value trunc = AtenTruncOp::create(rewriter, loc, resultTy, self); rewriter.replaceOpWithNewOp(op, resultTy, self, trunc, /*alpha=*/one); return success(); @@ -9392,13 +9461,13 @@ class DecomposeAtenCopysignTensorOp auto otherTy = cast(other.getType()); auto resultTy = op.getType(); - Value signbit = rewriter.create( - loc, + Value signbit = AtenSignbitOp::create( + rewriter, loc, otherTy.getWithSizesAndDtype(otherTy.getOptionalSizes(), rewriter.getI1Type()), other); - Value abs = rewriter.create(loc, selfTy, self); - Value neg = rewriter.create(loc, selfTy, abs); + Value abs = AtenAbsOp::create(rewriter, loc, selfTy, self); + Value neg = AtenNegOp::create(rewriter, loc, selfTy, abs); rewriter.replaceOpWithNewOp(op, resultTy, signbit, neg, abs); return success(); @@ 
-9422,11 +9491,11 @@ class DecomposeAtenLdexpTensorOp : public OpRewritePattern { return rewriter.notifyMatchFailure(op, "result must have dtype"); } - Value exp2 = rewriter.create( - loc, - resultTy.getWithSizesAndDtype(otherTy.getOptionalSizes(), - resultTy.getDtype()), - other); + Value exp2 = + AtenExp2Op::create(rewriter, loc, + resultTy.getWithSizesAndDtype( + otherTy.getOptionalSizes(), resultTy.getDtype()), + other); rewriter.replaceOpWithNewOp(op, resultTy, self, exp2); return success(); } @@ -9449,19 +9518,20 @@ class DecomposeAtenFmodTensorOp : public OpRewritePattern { } if (isa(resultTy.getDtype())) { - Value div = rewriter.create(loc, resultTy, self, other); - Value mul = rewriter.create(loc, resultTy, div, other); + Value div = AtenDivTensorOp::create(rewriter, loc, resultTy, self, other); + Value mul = AtenMulTensorOp::create(rewriter, loc, resultTy, div, other); Value alpha = - rewriter.create(loc, rewriter.getF64FloatAttr(1)); + ConstantFloatOp::create(rewriter, loc, rewriter.getF64FloatAttr(1)); rewriter.replaceOpWithNewOp(op, resultTy, self, mul, alpha); return success(); } else if (isa(resultTy.getDtype())) { - Value div = rewriter.create(loc, resultTy, self, other); - Value trunc = rewriter.create(loc, resultTy, div); - Value mul = rewriter.create(loc, resultTy, trunc, other); + Value div = AtenDivTensorOp::create(rewriter, loc, resultTy, self, other); + Value trunc = AtenTruncOp::create(rewriter, loc, resultTy, div); + Value mul = + AtenMulTensorOp::create(rewriter, loc, resultTy, trunc, other); Value alpha = - rewriter.create(loc, rewriter.getF64FloatAttr(1)); + ConstantFloatOp::create(rewriter, loc, rewriter.getF64FloatAttr(1)); rewriter.replaceOpWithNewOp(op, resultTy, self, mul, alpha); return success(); @@ -9479,10 +9549,10 @@ class DecomposeAtenBaddbmmOp : public OpRewritePattern { LogicalResult matchAndRewrite(AtenBaddbmmOp op, PatternRewriter &rewriter) const override { Location loc = op.getLoc(); - Value bmm = rewriter.create(loc, 
op.getType(), op.getBatch1(), - op.getBatch2()); - Value alphaTimesBmm = - rewriter.create(loc, op.getType(), bmm, op.getAlpha()); + Value bmm = AtenBmmOp::create(rewriter, loc, op.getType(), op.getBatch1(), + op.getBatch2()); + Value alphaTimesBmm = AtenMulScalarOp::create(rewriter, loc, op.getType(), + bmm, op.getAlpha()); Value input = op.getSelf(); BaseTensorType inputType = cast(input.getType()); BaseTensorType resultType = @@ -9508,7 +9578,7 @@ class DecomposeAtenFloorDivideOp : public OpRewritePattern { // PyTorch aten.floorDivide is a misnomer because it actually rounds // the quotient towards zero instead of taking its floor. Value cstStrFloor = - rewriter.create(op.getLoc(), "floor"); + Torch::ConstantStrOp::create(rewriter, op.getLoc(), "floor"); rewriter.replaceOpWithNewOp( op, op.getType(), op.getSelf(), op.getOther(), /*roundingMode=*/cstStrFloor); @@ -9524,7 +9594,7 @@ class DecomposeAtenFloorDivideScalarOp LogicalResult matchAndRewrite(AtenFloorDivideScalarOp op, PatternRewriter &rewriter) const override { Value cstStrFloor = - rewriter.create(op.getLoc(), "floor"); + Torch::ConstantStrOp::create(rewriter, op.getLoc(), "floor"); rewriter.replaceOpWithNewOp( op, op.getType(), op.getSelf(), op.getOther(), /*roundingMode=*/cstStrFloor); @@ -9551,11 +9621,12 @@ class DecomposeAtenNumpyTOp : public OpRewritePattern { SmallVector dimListInts(llvm::reverse( llvm::iota_range(0, inputRank, /*inclusive=*/false))); for (int dimListInt : dimListInts) { - dimListElements.push_back(rewriter.create( - loc, rewriter.getI64IntegerAttr(dimListInt))); + dimListElements.push_back(Torch::ConstantIntOp::create( + rewriter, loc, rewriter.getI64IntegerAttr(dimListInt))); } - Value dimList = rewriter.create( - loc, Torch::ListType::get(Torch::IntType::get(op->getContext())), + Value dimList = PrimListConstructOp::create( + rewriter, loc, + Torch::ListType::get(Torch::IntType::get(op->getContext())), dimListElements); rewriter.replaceOpWithNewOp(op, op.getType(), self, 
dimList); return success(); @@ -9614,10 +9685,11 @@ static LogicalResult calculateVariance(OpTy op, PatternRewriter &rewriter, } if (isNoneOrEmpty) { for (unsigned i = 0; i < inputRank; i++) - dimListElements.push_back(rewriter.create( - loc, rewriter.getI64IntegerAttr(i))); - dimList = rewriter.create( - loc, Torch::ListType::get(Torch::IntType::get(op.getContext())), + dimListElements.push_back(Torch::ConstantIntOp::create( + rewriter, loc, rewriter.getI64IntegerAttr(i))); + dimList = PrimListConstructOp::create( + rewriter, loc, + Torch::ListType::get(Torch::IntType::get(op.getContext())), dimListElements); } Type meanDimResultType = inputTensorTy; @@ -9627,56 +9699,58 @@ static LogicalResult calculateVariance(OpTy op, PatternRewriter &rewriter, dimListElements[i], /*keepDim=*/true); - Value constantNone = rewriter.create(loc); - Value constantTrue = rewriter.create(loc, true); - Value meanAlongDims = rewriter.create( - loc, meanDimResultType, self, dimList, /*keepDim=*/constantTrue, + Value constantNone = ConstantNoneOp::create(rewriter, loc); + Value constantTrue = ConstantBoolOp::create(rewriter, loc, true); + Value meanAlongDims = AtenMeanDimOp::create( + rewriter, loc, meanDimResultType, self, dimList, /*keepDim=*/constantTrue, /*dtype=*/constantNone); Value subMean = createTensorSub(rewriter, loc, inputTensorTy, self, meanAlongDims); - Value square = rewriter.create(loc, inputTensorTy, subMean); + Value square = AtenSquareOp::create(rewriter, loc, inputTensorTy, subMean); if (!unbiased) { - Value result = rewriter.create( - loc, newOutputType, square, dimList, keepDim, /*dtype=*/constantNone); + Value result = + AtenMeanDimOp::create(rewriter, loc, newOutputType, square, dimList, + keepDim, /*dtype=*/constantNone); result = convertTensorToDtype(rewriter, loc, result, outputTensorType.getDtype()); rewriter.replaceOp(op, result); return success(); } // Divide the square sum by productDimSize - correction. 
- Value squareSum = rewriter.create( - loc, newOutputType, square, dimList, keepDim, /*dtype=*/constantNone); + Value squareSum = + AtenSumDimIntListOp::create(rewriter, loc, newOutputType, square, dimList, + keepDim, /*dtype=*/constantNone); // `productDimSize` is product of sizes of dimensions to be reduced. - Value constantOne = - rewriter.create(loc, rewriter.getI64IntegerAttr(1)); + Value constantOne = Torch::ConstantIntOp::create( + rewriter, loc, rewriter.getI64IntegerAttr(1)); Value productDimSize = constantOne; for (Value dim : dimListElements) { - Value dimSize = rewriter.create(loc, self, dim); + Value dimSize = AtenSizeIntOp::create(rewriter, loc, self, dim); productDimSize = - rewriter.create(loc, productDimSize, dimSize); + AtenMulIntOp::create(rewriter, loc, productDimSize, dimSize); } - productDimSize = rewriter.create(loc, productDimSize); - constantOne = rewriter.create( - loc, rewriter.getF64FloatAttr(1.0)); - Value cstCorrection = rewriter.create( - loc, rewriter.getF64FloatAttr(correction)); + productDimSize = AtenFloatScalarOp::create(rewriter, loc, productDimSize); + constantOne = Torch::ConstantFloatOp::create(rewriter, loc, + rewriter.getF64FloatAttr(1.0)); + Value cstCorrection = Torch::ConstantFloatOp::create( + rewriter, loc, rewriter.getF64FloatAttr(correction)); // The `correction` value should be less than or equal to `productDimSize + // 1`. 
if (!isAssumingStrictSymbolicShapes(rewriter)) { - Value productDimSizePlusOne = rewriter.create( - loc, productDimSize.getType(), productDimSize, constantOne); - Value cond = rewriter.create(loc, productDimSizePlusOne, - cstCorrection); - rewriter.create( - loc, cond, + Value productDimSizePlusOne = AtenAddOp::create( + rewriter, loc, productDimSize.getType(), productDimSize, constantOne); + Value cond = AtenGeFloatOp::create(rewriter, loc, productDimSizePlusOne, + cstCorrection); + RuntimeAssertOp::create( + rewriter, loc, cond, "correction value should be less than or equal to productDimSize + 1"); } Value productDimSizeSubCorrection = - rewriter.create(loc, productDimSize, cstCorrection); - Value result = rewriter.create(loc, newOutputType, squareSum, - productDimSizeSubCorrection); + AtenSubFloatOp::create(rewriter, loc, productDimSize, cstCorrection); + Value result = AtenDivScalarOp::create( + rewriter, loc, newOutputType, squareSum, productDimSizeSubCorrection); result = convertTensorToDtype(rewriter, loc, result, outputTensorType.getDtype()); rewriter.replaceOp(op, result); @@ -9768,9 +9842,9 @@ class DecomposeAtenSelectScatterOp Value src = op.getSrc(); Value one = - rewriter.create(loc, rewriter.getI64IntegerAttr(1)); + ConstantIntOp::create(rewriter, loc, rewriter.getI64IntegerAttr(1)); Value startPlusOne = - rewriter.create(loc, one.getType(), start, one); + AtenAddIntOp::create(rewriter, loc, one.getType(), start, one); auto unsqueezedInfo = unsqueezeTensor(rewriter, op, src, dim); if (failed(unsqueezedInfo)) { @@ -9826,7 +9900,7 @@ class DecomposeAtenLiftFreshCopyOp using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(AtenLiftFreshCopyOp op, PatternRewriter &rewriter) const override { - Value constantNone = rewriter.create(op.getLoc()); + Value constantNone = ConstantNoneOp::create(rewriter, op.getLoc()); rewriter.replaceOpWithNewOp(op, op.getType(), op.getSelf(), /*memoryFormat=*/constantNone); return success(); @@ -9862,22 
+9936,23 @@ class DecomposeAtenMseLossOp : public OpRewritePattern { Value sub = createTensorSub(rewriter, loc, subType, op.getSelf(), op.getTarget()); - Value result = rewriter.create(loc, subType, sub); + Value result = AtenSquareOp::create(rewriter, loc, subType, sub); if (reductionType == torch_upstream::Reduction::None) { rewriter.replaceOp(op, result); return success(); } - Value cstFalse = rewriter.create(loc, false); - Value cstNone = rewriter.create(loc); + Value cstFalse = Torch::ConstantBoolOp::create(rewriter, loc, false); + Value cstNone = Torch::ConstantNoneOp::create(rewriter, loc); if (reductionType == torch_upstream::Reduction::Mean) - result = rewriter.create(loc, resultType, result, - /*dim=*/cstNone, - /*keepdim=*/cstFalse, - /*dtype=*/cstNone); + result = AtenMeanDimOp::create(rewriter, loc, resultType, result, + /*dim=*/cstNone, + /*keepdim=*/cstFalse, + /*dtype=*/cstNone); else - result = rewriter.create( - loc, resultType, result, /*dim=*/cstNone, /*keepdim=*/cstFalse, - /*dtype=*/cstNone); + result = + AtenSumDimIntListOp::create(rewriter, loc, resultType, result, + /*dim=*/cstNone, /*keepdim=*/cstFalse, + /*dtype=*/cstNone); rewriter.replaceOp(op, result); return success(); } @@ -9928,19 +10003,19 @@ class DecomposeAtenL1LossOp : public OpRewritePattern { auto subTy = outTy.getWithSizesAndDtype(selfTy.getSizes(), outDtype); Value sub = createTensorSub(rewriter, loc, subTy, self, target); - Value abs = rewriter.create(loc, subTy, sub); + Value abs = AtenAbsOp::create(rewriter, loc, subTy, sub); if (reductionInt == 0) { rewriter.replaceOp(op, abs); } else if (reductionInt == 1) { - Value none = rewriter.create(loc); - Value sum = rewriter.create(loc, outTy, abs, none); - Value numel = rewriter.create(loc, abs); - Value mean = rewriter.create(loc, outTy, sum, numel); + Value none = ConstantNoneOp::create(rewriter, loc); + Value sum = AtenSumOp::create(rewriter, loc, outTy, abs, none); + Value numel = AtenNumelOp::create(rewriter, loc, abs); + 
Value mean = AtenDivScalarOp::create(rewriter, loc, outTy, sum, numel); rewriter.replaceOp(op, mean); } else { - Value none = rewriter.create(loc); - Value sum = rewriter.create(loc, outTy, abs, none); + Value none = ConstantNoneOp::create(rewriter, loc); + Value sum = AtenSumOp::create(rewriter, loc, outTy, abs, none); rewriter.replaceOp(op, sum); } @@ -9958,11 +10033,11 @@ class DecomposeAtenNormScalarOptDimOp LogicalResult matchAndRewrite(AtenNormScalarOptDimOp op, PatternRewriter &rewriter) const override { Location loc = op->getLoc(); - Value none = rewriter.create(loc); + Value none = Torch::ConstantNoneOp::create(rewriter, loc); Value ord = op.getP(); if (isa(ord.getType())) { - ord = rewriter.create( - loc, rewriter.getF64FloatAttr(2.0)); + ord = Torch::ConstantFloatOp::create(rewriter, loc, + rewriter.getF64FloatAttr(2.0)); } rewriter.replaceOpWithNewOp( op, op.getType(), op.getSelf(), ord, op.getDim(), op.getKeepdim(), @@ -9995,27 +10070,27 @@ class DecomposeAtenRandintLowOp : public OpRewritePattern { return rewriter.notifyMatchFailure( op, "unimplemented: high must be a constant integer"); - Value none = rewriter.create(loc); - Value cstFalse = rewriter.create(loc, false); - Value low = rewriter.create( - loc, rewriter.getF64FloatAttr((double)cstLow)); - Value high = rewriter.create( - loc, rewriter.getF64FloatAttr((double)cstHigh)); + Value none = ConstantNoneOp::create(rewriter, loc); + Value cstFalse = ConstantBoolOp::create(rewriter, loc, false); + Value low = Torch::ConstantFloatOp::create( + rewriter, loc, rewriter.getF64FloatAttr((double)cstLow)); + Value high = Torch::ConstantFloatOp::create( + rewriter, loc, rewriter.getF64FloatAttr((double)cstHigh)); BaseTensorType floatResultType = cast(resultTensorType.getWithSizesAndDtype( resultTensorType.getSizes(), rewriter.getF32Type())); - Value emptyTensor = rewriter.create( - loc, floatResultType, op.getSize(), /*dtype=*/none, + Value emptyTensor = AtenEmptyMemoryFormatOp::create( + rewriter, loc, 
floatResultType, op.getSize(), /*dtype=*/none, /*layout=*/op.getLayout(), /*device=*/op.getDevice(), /*pinMemory=*/op.getPinMemory(), /*memoryFormat=*/none); Value result = - rewriter.create(loc, floatResultType, emptyTensor, - /*from=*/low, - /*to=*/high, - /*generator=*/none); + AtenUniformOp::create(rewriter, loc, floatResultType, emptyTensor, + /*from=*/low, + /*to=*/high, + /*generator=*/none); rewriter.replaceOpWithNewOp( op, resultType, result, getDtypeIntValueForType(rewriter, loc, resultTensorType.getDtype()), @@ -10035,8 +10110,8 @@ class DecomposeAtenRandintOp : public OpRewritePattern { Location loc = op.getLoc(); Type resultType = op.getType(); - Value low = rewriter.create( - loc, rewriter.getI64IntegerAttr(0)); + Value low = Torch::ConstantIntOp::create(rewriter, loc, + rewriter.getI64IntegerAttr(0)); rewriter.replaceOpWithNewOp( op, resultType, low, op.getHigh(), op.getSize(), op.getDtype(), @@ -10057,13 +10132,14 @@ class DecomposeAtenVarMeanCorrectionOp LogicalResult matchAndRewrite(AtenVarMeanCorrectionOp op, PatternRewriter &rewriter) const override { Location loc = op.getLoc(); - Value noneVal = rewriter.create(loc); - Value var = rewriter.create( - loc, op.getType(0), op.getSelf(), op.getDim(), op.getCorrection(), - op.getKeepdim()); - Value mean = rewriter.create( - loc, op.getType(0), op.getSelf(), op.getDim(), op.getKeepdim(), - /*dtype=*/noneVal); + Value noneVal = ConstantNoneOp::create(rewriter, loc); + Value var = AtenVarCorrectionOp::create( + rewriter, loc, op.getType(0), op.getSelf(), op.getDim(), + op.getCorrection(), op.getKeepdim()); + Value mean = + AtenMeanDimOp::create(rewriter, loc, op.getType(0), op.getSelf(), + op.getDim(), op.getKeepdim(), + /*dtype=*/noneVal); rewriter.replaceOp(op, {var, mean}); return success(); } @@ -10079,8 +10155,8 @@ class DecomposePrimsConvertElementTypeOp LogicalResult matchAndRewrite(PrimsConvertElementTypeOp op, PatternRewriter &rewriter) const override { Location loc = op.getLoc(); - Value 
cstFalse = rewriter.create(loc, false); - Value cstNone = rewriter.create(loc); + Value cstFalse = Torch::ConstantBoolOp::create(rewriter, loc, false); + Value cstNone = Torch::ConstantNoneOp::create(rewriter, loc); rewriter.replaceOpWithNewOp( op, op.getType(), op.getA(), op.getDtype(), /*nonBlocking=*/cstFalse, /*copy=*/cstFalse, /*memoryFormat=*/cstNone); @@ -10099,7 +10175,8 @@ class DecomposePrimsVarOp : public OpRewritePattern { if (!isa(op.getOutputDtype().getType())) return rewriter.notifyMatchFailure( op, "Unimplemented non-None dtype for prims::var op"); - Value cstFalse = rewriter.create(op.getLoc(), false); + Value cstFalse = + Torch::ConstantBoolOp::create(rewriter, op.getLoc(), false); rewriter.replaceOpWithNewOp( op, op.getType(), op.getInp(), op.getDims(), op.getCorrection(), /*keepdim=*/cstFalse); @@ -10139,45 +10216,43 @@ class DecomposeAtenRandnGeneratorOp op, "could not determine dtype from the op."); } - Value none = rewriter.create(loc); - Value low = rewriter.create( - loc, rewriter.getF64FloatAttr((double)0.0)); - Value high = rewriter.create( - loc, rewriter.getF64FloatAttr((double)1.0)); - Value cstMinusTwo = rewriter.create( - loc, rewriter.getF64FloatAttr((double)-2.0)); - Value cstTwoPie = rewriter.create( - loc, rewriter.getF64FloatAttr((double)(2.0 * 3.14159))); + Value none = ConstantNoneOp::create(rewriter, loc); + Value low = Torch::ConstantFloatOp::create( + rewriter, loc, rewriter.getF64FloatAttr((double)0.0)); + Value high = Torch::ConstantFloatOp::create( + rewriter, loc, rewriter.getF64FloatAttr((double)1.0)); + Value cstMinusTwo = Torch::ConstantFloatOp::create( + rewriter, loc, rewriter.getF64FloatAttr((double)-2.0)); + Value cstTwoPie = Torch::ConstantFloatOp::create( + rewriter, loc, rewriter.getF64FloatAttr((double)(2.0 * 3.14159))); - Value emptyTensorA = rewriter.create( - loc, resultType, op.getSize(), /*dtype=*/*dtype, + Value emptyTensorA = AtenEmptyMemoryFormatOp::create( + rewriter, loc, resultType, op.getSize(), 
/*dtype=*/*dtype, /*layout=*/op.getLayout(), /*device=*/op.getDevice(), /*pin_memory=*/op.getPinMemory(), /*memory_format=*/none); - Value emptyTensorB = rewriter.create( - loc, resultType, op.getSize(), /*dtype=*/*dtype, + Value emptyTensorB = AtenEmptyMemoryFormatOp::create( + rewriter, loc, resultType, op.getSize(), /*dtype=*/*dtype, /*layout=*/op.getLayout(), /*device=*/op.getDevice(), /*pin_memory=*/op.getPinMemory(), /*memory_format=*/none); - Value uOne = - rewriter.create(loc, resultType, emptyTensorA, + Value uOne = AtenUniformOp::create(rewriter, loc, resultType, emptyTensorA, /*from=*/low, /*to=*/high, /*generator=*/op.getGenerator()); - Value uTwo = - rewriter.create(loc, resultType, emptyTensorB, + Value uTwo = AtenUniformOp::create(rewriter, loc, resultType, emptyTensorB, /*from=*/low, /*to=*/high, /*generator=*/op.getGenerator()); - Value logUOne = rewriter.create(loc, resultType, uOne); - Value minusTwoLogUOne = - rewriter.create(loc, resultType, logUOne, cstMinusTwo); - Value r = rewriter.create(loc, resultType, minusTwoLogUOne); + Value logUOne = AtenLogOp::create(rewriter, loc, resultType, uOne); + Value minusTwoLogUOne = AtenMulScalarOp::create(rewriter, loc, resultType, + logUOne, cstMinusTwo); + Value r = AtenSqrtOp::create(rewriter, loc, resultType, minusTwoLogUOne); Value theta = - rewriter.create(loc, resultType, uTwo, cstTwoPie); - Value cosTheta = rewriter.create(loc, resultType, theta); + AtenMulScalarOp::create(rewriter, loc, resultType, uTwo, cstTwoPie); + Value cosTheta = AtenCosOp::create(rewriter, loc, resultType, theta); rewriter.replaceOpWithNewOp(op, op.getType(), r, cosTheta); return success(); } @@ -10191,7 +10266,7 @@ class DecomposeAtenRandnOp : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(AtenRandnOp op, PatternRewriter &rewriter) const override { - Value none = rewriter.create(op.getLoc()); + Value none = Torch::ConstantNoneOp::create(rewriter, op.getLoc()); 
rewriter.replaceOpWithNewOp( op, op.getType(), op.getSize(), /*generator=*/none, op.getDtype(), op.getLayout(), op.getDevice(), op.getPinMemory()); @@ -10221,11 +10296,11 @@ class DecomposeAtenRandnLikeOp : public OpRewritePattern { op, "unimplemented: only none, contiguous and preserve " "memory_format is supported"); } - Value none = rewriter.create(op.getLoc()); + Value none = Torch::ConstantNoneOp::create(rewriter, op.getLoc()); auto sizeListType = Torch::ListType::get(Torch::IntType::get(op.getContext())); Value sizeList = - rewriter.create(op.getLoc(), sizeListType, op.getSelf()); + AtenSizeOp::create(rewriter, op.getLoc(), sizeListType, op.getSelf()); rewriter.replaceOpWithNewOp( op, op.getType(), sizeList, /*generator=*/none, op.getDtype(), op.getLayout(), op.getDevice(), op.getPinMemory()); @@ -10248,18 +10323,18 @@ class DecomposeAtenRandOp : public OpRewritePattern { return rewriter.notifyMatchFailure( op, "expected result type to have a dtype"); } - Value noneVal = rewriter.create(loc); - Value low = rewriter.create( - loc, rewriter.getF64FloatAttr((double)0.0)); - Value high = rewriter.create( - loc, rewriter.getF64FloatAttr((double)1.0)); + Value noneVal = Torch::ConstantNoneOp::create(rewriter, loc); + Value low = Torch::ConstantFloatOp::create( + rewriter, loc, rewriter.getF64FloatAttr((double)0.0)); + Value high = Torch::ConstantFloatOp::create( + rewriter, loc, rewriter.getF64FloatAttr((double)1.0)); FailureOr dtype = getDtypeFromOp(rewriter, op); if (failed(dtype)) { return rewriter.notifyMatchFailure( op, "could not determine dtype from the op."); } - Value emptyTensor = rewriter.create( - loc, resultType, op.getSize(), /*dtype=*/*dtype, + Value emptyTensor = AtenEmptyMemoryFormatOp::create( + rewriter, loc, resultType, op.getSize(), /*dtype=*/*dtype, /*layout=*/op.getLayout(), /*device=*/op.getDevice(), /*pin_memory=*/op.getPinMemory(), /*memory_format=*/noneVal); @@ -10281,12 +10356,12 @@ class DecomposeAtenLinspaceOp : public OpRewritePattern 
{ Location loc = op.getLoc(); MLIRContext *context = getContext(); - Value none = rewriter.create(loc); - Value falseVal = rewriter.create(loc, false); + Value none = ConstantNoneOp::create(rewriter, loc); + Value falseVal = ConstantBoolOp::create(rewriter, loc, false); Value zero = - rewriter.create(loc, rewriter.getI64IntegerAttr(0)); + ConstantIntOp::create(rewriter, loc, rewriter.getI64IntegerAttr(0)); Value one = - rewriter.create(loc, rewriter.getI64IntegerAttr(1)); + ConstantIntOp::create(rewriter, loc, rewriter.getI64IntegerAttr(1)); Value addStart; int64_t steps; @@ -10298,57 +10373,58 @@ class DecomposeAtenLinspaceOp : public OpRewritePattern { getTensorTypeFromShapeValues({op.getSteps()}, fp32Type); if (matchPattern(op.getSteps(), m_TorchConstantInt(&steps)) && steps == 1) { // specically handle steps == 1 - Value arange = rewriter.create( - loc, arangeIntType, zero, op.getSteps(), /*dtype=*/none, + Value arange = AtenArangeStartOp::create( + rewriter, loc, arangeIntType, zero, op.getSteps(), /*dtype=*/none, op.getLayout(), op.getDevice(), op.getPinMemory()); if (isa(op.getEnd().getType()) || isa(op.getStart().getType())) { - addStart = rewriter.create(loc, arangeFp32Type, arange, - op.getStart(), one); + addStart = AtenAddScalarOp::create(rewriter, loc, arangeFp32Type, + arange, op.getStart(), one); } else { - addStart = rewriter.create(loc, arangeIntType, arange, - op.getStart(), one); + addStart = AtenAddScalarOp::create(rewriter, loc, arangeIntType, arange, + op.getStart(), one); } } else { // handle steps != 1 or dynamic steps - Value neOrNot = rewriter.create(loc, op.getSteps(), one); - rewriter.create( - loc, neOrNot, + Value neOrNot = AtenNeIntOp::create(rewriter, loc, op.getSteps(), one); + RuntimeAssertOp::create( + rewriter, loc, neOrNot, rewriter.getStringAttr("linspace's dynamic steps must not be 1")); // create arange: [0, ..., steps - 1] - Value arange = rewriter.create( - loc, arangeIntType, zero, op.getSteps(), /*dtype=*/none, + Value 
arange = AtenArangeStartOp::create( + rewriter, loc, arangeIntType, zero, op.getSteps(), /*dtype=*/none, op.getLayout(), op.getDevice(), op.getPinMemory()); // calculate (end - start) / (steps - 1) Value sub; if (isa(op.getEnd().getType()) || isa(op.getStart().getType())) { - sub = rewriter.create(loc, Torch::FloatType::get(context), - op.getEnd(), op.getStart()); + sub = AtenSubOp::create(rewriter, loc, Torch::FloatType::get(context), + op.getEnd(), op.getStart()); } else { - sub = rewriter.create(loc, op.getEnd(), op.getStart()); + sub = AtenSubIntOp::create(rewriter, loc, op.getEnd(), op.getStart()); } - Value div = rewriter.create( - loc, sub, rewriter.create(loc, op.getSteps(), one)); + Value div = AtenDivOp::create( + rewriter, loc, sub, + AtenSubIntOp::create(rewriter, loc, op.getSteps(), one)); // calculate [0, ..., steps - 1] * ((end - start) / (steps - 1)) + start Value mulScalar = - rewriter.create(loc, arangeFp32Type, arange, div); - addStart = rewriter.create( - loc, arangeFp32Type, mulScalar, op.getStart(), one); + AtenMulScalarOp::create(rewriter, loc, arangeFp32Type, arange, div); + addStart = AtenAddScalarOp::create(rewriter, loc, arangeFp32Type, + mulScalar, op.getStart(), one); } // to dtype Value result; if (!isa(op.getDtype().getType())) { - result = rewriter.create( - loc, op.getType(), addStart, op.getDtype(), - /*non_blocking=*/falseVal, - /*copy=*/falseVal, /*memory_format=*/none); + result = AtenToDtypeOp::create(rewriter, loc, op.getType(), addStart, + op.getDtype(), + /*non_blocking=*/falseVal, + /*copy=*/falseVal, /*memory_format=*/none); } else { - Value f32Type = rewriter.create( - loc, (int)torch_upstream::ScalarType::Float); - result = rewriter.create( - loc, op.getType(), addStart, f32Type, /*non_blocking=*/falseVal, - /*copy=*/falseVal, /*memory_format=*/none); + Value f32Type = ConstantIntOp::create( + rewriter, loc, (int)torch_upstream::ScalarType::Float); + result = AtenToDtypeOp::create(rewriter, loc, op.getType(), addStart, + 
f32Type, /*non_blocking=*/falseVal, + /*copy=*/falseVal, /*memory_format=*/none); } rewriter.replaceOp(op, result); return success(); @@ -10363,13 +10439,13 @@ class DecomposeAtenVarMeanOp : public OpRewritePattern { LogicalResult matchAndRewrite(AtenVarMeanOp op, PatternRewriter &rewriter) const override { Location loc = op.getLoc(); - Value falseVal = rewriter.create(loc, false); - Value noneVal = rewriter.create(loc); - Value var = rewriter.create(loc, op.getType(0), op.getSelf(), - /*dim=*/noneVal, op.getUnbiased(), - /*keepdim=*/falseVal); - Value mean = rewriter.create(loc, op.getType(0), op.getSelf(), - /*dtype=*/noneVal); + Value falseVal = ConstantBoolOp::create(rewriter, loc, false); + Value noneVal = ConstantNoneOp::create(rewriter, loc); + Value var = AtenVarDimOp::create(rewriter, loc, op.getType(0), op.getSelf(), + /*dim=*/noneVal, op.getUnbiased(), + /*keepdim=*/falseVal); + Value mean = AtenMeanOp::create(rewriter, loc, op.getType(0), op.getSelf(), + /*dtype=*/noneVal); rewriter.replaceOp(op, {var, mean}); return success(); } @@ -10415,7 +10491,7 @@ class DecomposeAtenEmptyStridedOp return rewriter.notifyMatchFailure( op, "Unable to determine if stride is default"); - Value noneVal = rewriter.create(op.getLoc()); + Value noneVal = ConstantNoneOp::create(rewriter, op.getLoc()); FailureOr dtype = getDtypeFromOp(rewriter, op); if (failed(dtype)) { @@ -10508,10 +10584,11 @@ class DecomposeAtenMovedimIntOp : public OpRewritePattern { computeDimsOrderForMoveDim(srcDimInt, dstDimInt, inputRank); SmallVector cstDimsOrder; for (int64_t dim : dimsOrder) - cstDimsOrder.push_back(rewriter.create( - loc, rewriter.getI64IntegerAttr(dim))); - Value permuteDimsOrder = rewriter.create( - loc, Torch::ListType::get(Torch::IntType::get(op->getContext())), + cstDimsOrder.push_back(Torch::ConstantIntOp::create( + rewriter, loc, rewriter.getI64IntegerAttr(dim))); + Value permuteDimsOrder = PrimListConstructOp::create( + rewriter, loc, + 
Torch::ListType::get(Torch::IntType::get(op->getContext())), cstDimsOrder); rewriter.replaceOpWithNewOp(op, op.getType(), input, permuteDimsOrder); @@ -10571,11 +10648,11 @@ class DecomposeAtenCrossEntropyLossOp "value of 0.0 for label_smoothing"); } - Value noneVal = rewriter.create(loc); - Value dim = rewriter.create( - loc, rewriter.getI64IntegerAttr(1)); - Value logSoftmax = rewriter.create( - loc, self.getType(), self, dim, /*dtype=*/noneVal); + Value noneVal = ConstantNoneOp::create(rewriter, loc); + Value dim = Torch::ConstantIntOp::create(rewriter, loc, + rewriter.getI64IntegerAttr(1)); + Value logSoftmax = AtenLogSoftmaxIntOp::create( + rewriter, loc, self.getType(), self, dim, /*dtype=*/noneVal); Type secondType; if (reductionInt == 0) { @@ -10586,10 +10663,9 @@ class DecomposeAtenCrossEntropyLossOp } Value nllLoss = - rewriter - .create( - loc, op.getType(), secondType, logSoftmax, target, - op.getWeight(), op.getReduction(), op.getIgnoreIndex()) + AtenNllLossForwardOp::create(rewriter, loc, op.getType(), secondType, + logSoftmax, target, op.getWeight(), + op.getReduction(), op.getIgnoreIndex()) ->getResult(0); rewriter.replaceOp(op, nllLoss); return success(); @@ -10685,8 +10761,8 @@ class DecomposeAtenNllLossForwardOp if (selfRank < 2) { channelDim = 0; } - Value channelDimValue = rewriter.create( - loc, rewriter.getI64IntegerAttr(channelDim)); + Value channelDimValue = ConstantIntOp::create( + rewriter, loc, rewriter.getI64IntegerAttr(channelDim)); auto ignoreIndex = op.getIgnoreIndex(); Value w; @@ -10698,37 +10774,37 @@ class DecomposeAtenNllLossForwardOp newShapeList[channelDim] = weightSizes[0]; SmallVector newShapeListValue; for (size_t i = 0; i < newShapeList.size(); ++i) { - newShapeListValue.push_back(rewriter.create( - loc, rewriter.getI64IntegerAttr(newShapeList[i]))); + newShapeListValue.push_back(ConstantIntOp::create( + rewriter, loc, rewriter.getI64IntegerAttr(newShapeList[i]))); } - Value newShape = rewriter.create( - loc, - 
rewriter.getType( - rewriter.getType()), - newShapeListValue); + Value newShape = + PrimListConstructOp::create(rewriter, loc, + rewriter.getType( + rewriter.getType()), + newShapeListValue); auto newType = weightType.getWithSizesAndDtype(newShapeList, weightType.getDtype()); - w = rewriter.create(loc, newType, weight, newShape); + w = AtenViewOp::create(rewriter, loc, newType, weight, newShape); } else { w = weight; } - self = rewriter.create(loc, self.getType(), self, w); + self = AtenMulTensorOp::create(rewriter, loc, self.getType(), self, w); } SmallVector targetDimSizes(targetSizes); Value zero = - rewriter.create(loc, rewriter.getI64IntegerAttr(0)); + ConstantIntOp::create(rewriter, loc, rewriter.getI64IntegerAttr(0)); auto condType = ValueTensorType::get(ctx, targetDimSizes, rewriter.getI1Type()); auto unequalCond = - rewriter.create(loc, condType, target, ignoreIndex); + AtenNeScalarOp::create(rewriter, loc, condType, target, ignoreIndex); auto zeroTensorType = ValueTensorType::get(ctx, {}, rewriter.getIntegerType(64, true)); Value zeroTensor = - rewriter.create(loc, zeroTensorType, zero); - auto safeTarget = rewriter.create( - loc, target.getType(), unequalCond, target, zeroTensor); + PrimNumToTensorScalarOp::create(rewriter, loc, zeroTensorType, zero); + auto safeTarget = AtenWhereSelfOp::create(rewriter, loc, target.getType(), + unequalCond, target, zeroTensor); SmallVector safeTargetShape; for (size_t i = 0; i < targetSizes.size(); ++i) { @@ -10743,37 +10819,40 @@ class DecomposeAtenNllLossForwardOp auto gatherType = ValueTensorType::get(ctx, safeTargetShape, targetType.getDtype()); - auto safeTarget_ = rewriter.create( - loc, gatherType, safeTarget, channelDimValue); + auto safeTarget_ = AtenUnsqueezeOp::create(rewriter, loc, gatherType, + safeTarget, channelDimValue); auto falseValue = - rewriter.create(loc, rewriter.getBoolAttr(false)); - auto none = rewriter.create(loc); - auto _gather = rewriter.create( - loc, ValueTensorType::get(ctx, 
safeTargetShape, selfType.getDtype()), - self, channelDimValue, safeTarget_, falseValue); - Value gather = rewriter.create(loc, _gather.getType(), _gather); + ConstantBoolOp::create(rewriter, loc, rewriter.getBoolAttr(false)); + auto none = ConstantNoneOp::create(rewriter, loc); + auto _gather = AtenGatherOp::create( + rewriter, loc, + ValueTensorType::get(ctx, safeTargetShape, selfType.getDtype()), self, + channelDimValue, safeTarget_, falseValue); + Value gather = AtenNegOp::create(rewriter, loc, _gather.getType(), _gather); auto unequalCondType = cast(unequalCond.getType()); - auto result = rewriter.create( - loc, + auto result = AtenWhereSelfOp::create( + rewriter, loc, unequalCondType.getWithSizesAndDtype(unequalCondType.getSizes(), selfType.getDtype()), unequalCond, - rewriter.create( - loc, ValueTensorType::get(ctx, targetSizes, selfType.getDtype()), - gather, channelDimValue), + AtenSqueezeDimOp::create( + rewriter, loc, + ValueTensorType::get(ctx, targetSizes, selfType.getDtype()), gather, + channelDimValue), zeroTensor); Value totalWeight; if (reduction == 0 && selfRank > 1) { auto zeroFloat = - rewriter.create(loc, rewriter.getF64FloatAttr(0.0)); - Value twSize = rewriter.create( - loc, + ConstantFloatOp::create(rewriter, loc, rewriter.getF64FloatAttr(0.0)); + Value twSize = PrimListConstructOp::create( + rewriter, loc, rewriter.getType(rewriter.getType()), ValueRange({})); - totalWeight = rewriter.create( - loc, op.getType(1), self, twSize, zeroFloat, none, none, none, none); + totalWeight = + AtenNewFullOp::create(rewriter, loc, op.getType(1), self, twSize, + zeroFloat, none, none, none, none); rewriter.replaceOp(op, {result, totalWeight}); return success(); @@ -10784,37 +10863,39 @@ class DecomposeAtenNllLossForwardOp auto newWType = wType.getWithSizesAndDtype(selfSizes, wType.getDtype()); SmallVector selfSizesValue; for (size_t i = 0; i < selfSizes.size(); ++i) { - selfSizesValue.push_back(rewriter.create( - loc, 
rewriter.getI64IntegerAttr(selfSizes[i]))); + selfSizesValue.push_back(ConstantIntOp::create( + rewriter, loc, rewriter.getI64IntegerAttr(selfSizes[i]))); } - auto wSize = rewriter.create( - loc, + auto wSize = PrimListConstructOp::create( + rewriter, loc, rewriter.getType(rewriter.getType()), selfSizesValue); - w = rewriter.create(loc, newWType, w, wSize, falseValue); - auto wSumGather = rewriter.create( - loc, ValueTensorType::get(ctx, safeTargetShape, wType.getDtype()), w, + w = AtenExpandOp::create(rewriter, loc, newWType, w, wSize, falseValue); + auto wSumGather = AtenGatherOp::create( + rewriter, loc, + ValueTensorType::get(ctx, safeTargetShape, wType.getDtype()), w, channelDimValue, safeTarget_, falseValue); - auto wSumSq = rewriter.create( - loc, ValueTensorType::get(ctx, targetSizes, wType.getDtype()), - wSumGather, channelDimValue); - auto wSum = rewriter.create( - loc, + auto wSumSq = AtenSqueezeDimOp::create( + rewriter, loc, + ValueTensorType::get(ctx, targetSizes, wType.getDtype()), wSumGather, + channelDimValue); + auto wSum = AtenWhereSelfOp::create( + rewriter, loc, ValueTensorType::get(ctx, unequalCondType.getSizes(), wType.getDtype()), unequalCond, wSumSq, zeroTensor); - totalWeight = rewriter.create(loc, op.getType(1), wSum, none); + totalWeight = AtenSumOp::create(rewriter, loc, op.getType(1), wSum, none); } else { totalWeight = - rewriter.create(loc, op.getType(1), unequalCond, none); + AtenSumOp::create(rewriter, loc, op.getType(1), unequalCond, none); } auto resultSum = - rewriter.create(loc, op.getType(0), result, none); + AtenSumOp::create(rewriter, loc, op.getType(0), result, none); if (reduction == 1) { - auto resultMean = rewriter.create( - loc, op.getType(0), resultSum, totalWeight); + auto resultMean = AtenDivTensorOp::create(rewriter, loc, op.getType(0), + resultSum, totalWeight); rewriter.replaceOp(op, {resultMean, totalWeight}); return success(); @@ -10863,38 +10944,38 @@ class DecomposeAtenPoissonNllLossOp op, "Unimplemented: full 
loss computation is not supported"); Value one = - rewriter.create(loc, rewriter.getF64FloatAttr(1.0)); - Value epsConst = rewriter.create( - loc, rewriter.getF64FloatAttr(epsFloat)); + ConstantFloatOp::create(rewriter, loc, rewriter.getF64FloatAttr(1.0)); + Value epsConst = ConstantFloatOp::create( + rewriter, loc, rewriter.getF64FloatAttr(epsFloat)); - Value safeInput = rewriter.create(loc, input.getType(), - input, epsConst, one); + Value safeInput = AtenAddScalarOp::create(rewriter, loc, input.getType(), + input, epsConst, one); Value loss; if (logInVal) { - Value expIn = rewriter.create(loc, input.getType(), input); - Value targetMulInput = - rewriter.create(loc, input.getType(), target, input); - loss = rewriter.create(loc, input.getType(), expIn, - targetMulInput, one); + Value expIn = AtenExpOp::create(rewriter, loc, input.getType(), input); + Value targetMulInput = AtenMulTensorOp::create( + rewriter, loc, input.getType(), target, input); + loss = AtenSubTensorOp::create(rewriter, loc, input.getType(), expIn, + targetMulInput, one); } else { Value logSafeInput = - rewriter.create(loc, input.getType(), safeInput); - Value targetMulLog = rewriter.create( - loc, input.getType(), target, logSafeInput); - loss = rewriter.create(loc, input.getType(), input, - targetMulLog, one); + AtenLogOp::create(rewriter, loc, input.getType(), safeInput); + Value targetMulLog = AtenMulTensorOp::create( + rewriter, loc, input.getType(), target, logSafeInput); + loss = AtenSubTensorOp::create(rewriter, loc, input.getType(), input, + targetMulLog, one); } Value result = loss; if (reductionInt == 1) { // Case 1: Mean Reduction - result = rewriter.create( - loc, op.getType(), loss, rewriter.create(loc)); + result = AtenMeanOp::create(rewriter, loc, op.getType(), loss, + ConstantNoneOp::create(rewriter, loc)); } else if (reductionInt == 2) { // Case 2: Sum Reduction - result = rewriter.create(loc, op.getType(), loss, - rewriter.create(loc)); + result = AtenSumOp::create(rewriter, 
loc, op.getType(), loss, + ConstantNoneOp::create(rewriter, loc)); } rewriter.replaceOp(op, result); return success(); @@ -10936,22 +11017,22 @@ class DecomposeAtenKlDivOp : public OpRewritePattern { // Default: target tensor is not in log space Value logOfTarget; if (!logTargetBool) { - logOfTarget = rewriter.create(loc, targetTy, target); + logOfTarget = AtenLogOp::create(rewriter, loc, targetTy, target); } else { logOfTarget = target; } Value constOne = - rewriter.create(loc, rewriter.getI64IntegerAttr(1)); - Value subValue = rewriter.create(loc, selfTy, logOfTarget, - self, constOne); + ConstantIntOp::create(rewriter, loc, rewriter.getI64IntegerAttr(1)); + Value subValue = AtenSubTensorOp::create(rewriter, loc, selfTy, logOfTarget, + self, constOne); // if target tensor is already in log space if (logTargetBool) { - target = rewriter.create(loc, targetTy, target); + target = AtenExpOp::create(rewriter, loc, targetTy, target); } Value lossPointwise = - rewriter.create(loc, targetTy, target, subValue); + AtenMulTensorOp::create(rewriter, loc, targetTy, target, subValue); // Extract reduction int value from reduction argument int64_t reduction; @@ -10961,13 +11042,13 @@ class DecomposeAtenKlDivOp : public OpRewritePattern { } Value loss; - Value none = rewriter.create(loc); + Value none = ConstantNoneOp::create(rewriter, loc); // reduction: mean if (reduction == 1) { - loss = rewriter.create(loc, outTy, lossPointwise, none); + loss = AtenMeanOp::create(rewriter, loc, outTy, lossPointwise, none); } else if (reduction == 2) { // reduction: sum - loss = rewriter.create(loc, outTy, lossPointwise, none); + loss = AtenSumOp::create(rewriter, loc, outTy, lossPointwise, none); } else { // reduction: none loss = lossPointwise; @@ -10995,38 +11076,38 @@ class DecomposeAtenBinaryCrossEntropyWithLogitsOp Value loss; auto one = - rewriter.create(loc, rewriter.getI64IntegerAttr(1)); + ConstantIntOp::create(rewriter, loc, rewriter.getI64IntegerAttr(1)); auto _one = - 
rewriter.create(loc, rewriter.getI64IntegerAttr(-1)); + ConstantIntOp::create(rewriter, loc, rewriter.getI64IntegerAttr(-1)); auto _target = - rewriter.create(loc, target.getType(), target, _one); - auto _target_1 = rewriter.create(loc, _target.getType(), - _target, one, one); + AtenMulScalarOp::create(rewriter, loc, target.getType(), target, _one); + auto _target_1 = AtenAddScalarOp::create(rewriter, loc, _target.getType(), + _target, one, one); Value mm = - rewriter.create(loc, self.getType(), _target_1, self); + AtenMulTensorOp::create(rewriter, loc, self.getType(), _target_1, self); Value logSigm = - rewriter.create(loc, self.getType(), self); + AtenLogSigmoidOp::create(rewriter, loc, self.getType(), self); if (!isa(posWeight.getType())) { - auto logWeight = rewriter.create( - loc, posWeight.getType(), - rewriter.create(loc, posWeight.getType(), posWeight, - one, one), + auto logWeight = AtenAddScalarOp::create( + rewriter, loc, posWeight.getType(), + AtenSubScalarOp::create(rewriter, loc, posWeight.getType(), posWeight, + one, one), one, one); - loss = rewriter.create( - loc, mm.getType(), mm, - rewriter.create(loc, logWeight.getType(), logWeight, - logSigm), + loss = AtenSubTensorOp::create( + rewriter, loc, mm.getType(), mm, + AtenMulTensorOp::create(rewriter, loc, logWeight.getType(), logWeight, + logSigm), one); } else { - loss = - rewriter.create(loc, mm.getType(), mm, logSigm, one); + loss = AtenSubTensorOp::create(rewriter, loc, mm.getType(), mm, logSigm, + one); } if (!isa(weight.getType())) { loss = - rewriter.create(loc, loss.getType(), loss, weight); + AtenMulTensorOp::create(rewriter, loc, loss.getType(), loss, weight); } // apply loss reduction. 
@@ -11035,12 +11116,12 @@ class DecomposeAtenBinaryCrossEntropyWithLogitsOp return rewriter.notifyMatchFailure(op, "no reduction type is appointed!"); } - auto none = rewriter.create(loc); + auto none = ConstantNoneOp::create(rewriter, loc); Value res; if (reductionInt == 1) { - res = rewriter.create(loc, op.getType(), loss, none); + res = AtenMeanOp::create(rewriter, loc, op.getType(), loss, none); } else if (reductionInt == 2) { - res = rewriter.create(loc, op.getType(), loss, none); + res = AtenSumOp::create(rewriter, loc, op.getType(), loss, none); } else { res = loss; } @@ -11065,9 +11146,9 @@ class DecomposeAtenExp2Op : public OpRewritePattern { } auto two = - rewriter.create(loc, rewriter.getI64IntegerAttr(2)); + ConstantIntOp::create(rewriter, loc, rewriter.getI64IntegerAttr(2)); Value to = convertTensorToDtype(rewriter, loc, self, resultTy.getDtype()); - Value pow = rewriter.create(loc, resultTy, two, to); + Value pow = AtenPowScalarOp::create(rewriter, loc, resultTy, two, to); rewriter.replaceOp(op, pow); return success(); } @@ -11092,19 +11173,20 @@ class DecomposeAtenOneHotOp : public OpRewritePattern { int64_t numClasses = Torch::kUnknownSize; auto resultType = cast(op.getType()); matchPattern(op.getNumClasses(), m_TorchConstantInt(&numClasses)); - Value none = rewriter.create(loc); + Value none = ConstantNoneOp::create(rewriter, loc); // arange tensor auto si64Type = IntegerType::get(context, 64, IntegerType::Signed); auto arangeType = ValueTensorType::get(context, llvm::ArrayRef(numClasses), si64Type); - Value arangeTensor = rewriter.create( - loc, arangeType, op.getNumClasses(), /*dtype=*/none, /*layout=*/none, - /*device=*/none, /*pin_memory=*/none); + Value arangeTensor = + AtenArangeOp::create(rewriter, loc, arangeType, op.getNumClasses(), + /*dtype=*/none, /*layout=*/none, + /*device=*/none, /*pin_memory=*/none); // unsqueeze input - Value rankV = rewriter.create( - loc, rewriter.getI64IntegerAttr(inputRank)); + Value rankV = 
ConstantIntOp::create(rewriter, loc, + rewriter.getI64IntegerAttr(inputRank)); auto unsqueeze = Torch::unsqueezeTensor(rewriter, op, input, rankV); if (failed(unsqueeze)) return rewriter.notifyMatchFailure(op, @@ -11117,8 +11199,8 @@ class DecomposeAtenOneHotOp : public OpRewritePattern { auto eqType = ValueTensorType::get( context, cast(op.getType()).getSizes(), IntegerType::get(context, 1)); - Value eqTensor = rewriter.create( - loc, eqType, unsqueezeTensor, arangeTensor); + Value eqTensor = AtenEqTensorOp::create(rewriter, loc, eqType, + unsqueezeTensor, arangeTensor); // convert to si64 Value result = @@ -11139,13 +11221,14 @@ class DecomposeAtenVarMeanDimOp : public OpRewritePattern { LogicalResult matchAndRewrite(AtenVarMeanDimOp op, PatternRewriter &rewriter) const override { Location loc = op.getLoc(); - Value noneVal = rewriter.create(loc); - Value var = rewriter.create(loc, op.getType(0), op.getSelf(), - op.getDim(), op.getUnbiased(), - op.getKeepdim()); - Value mean = rewriter.create( - loc, op.getType(0), op.getSelf(), op.getDim(), op.getKeepdim(), - /*dtype=*/noneVal); + Value noneVal = ConstantNoneOp::create(rewriter, loc); + Value var = + AtenVarDimOp::create(rewriter, loc, op.getType(0), op.getSelf(), + op.getDim(), op.getUnbiased(), op.getKeepdim()); + Value mean = + AtenMeanDimOp::create(rewriter, loc, op.getType(0), op.getSelf(), + op.getDim(), op.getKeepdim(), + /*dtype=*/noneVal); rewriter.replaceOp(op, {var, mean}); return success(); } @@ -11163,17 +11246,18 @@ class DecomposeAtenScalarTensor : public OpRewritePattern { auto resultTy = cast(op.getResult().getType()); auto scalarTy = getBuiltInTypeForTorchScalar(op.getS().getType()); - Value numToTensor = rewriter.create( - op.getLoc(), + Value numToTensor = PrimNumToTensorScalarOp::create( + rewriter, op.getLoc(), resultTy.getWithSizesAndDtype(resultTy.getOptionalSizes(), scalarTy), op.getS()); - Value cstNone = rewriter.create(op.getLoc()); - Value cstFalse = rewriter.create(op.getLoc(), 
false); + Value cstNone = ConstantNoneOp::create(rewriter, op.getLoc()); + Value cstFalse = + Torch::ConstantBoolOp::create(rewriter, op.getLoc(), false); Value dtype = getDtypeIntValueForType(rewriter, op.getLoc(), resultTy.getDtype()); - Value toDTypeLayout = rewriter.create( - op.getLoc(), op.getType(), numToTensor, dtype, op.getLayout(), + Value toDTypeLayout = AtenToDtypeLayoutOp::create( + rewriter, op.getLoc(), op.getType(), numToTensor, dtype, op.getLayout(), op.getDevice(), op.getPinMemory(), /*non_blocking=*/cstFalse, /*copy=*/cstFalse, /*memory_format=*/cstNone); rewriter.replaceOp(op, toDTypeLayout); @@ -11206,19 +11290,21 @@ class DecomposeAtenTopkOp : public OpRewritePattern { auto sortIndicesType = selfType.getWithSizesAndDtype( selfType.getOptionalSizes(), IntegerType::get(context, 64, IntegerType::Signed)); - auto sortOpResult = rewriter.create( - loc, self.getType(), sortIndicesType, self, dim, - /*descending=*/op.getLargest()); - Value start = rewriter.create( - loc, rewriter.getI64IntegerAttr(0)); - Value step = rewriter.create( - loc, rewriter.getI64IntegerAttr(1)); - Value resultValue = rewriter.create( - loc, op->getResultTypes()[0], sortOpResult->getResult(0), dim, start, - /*end=*/op.getK(), step); - Value resultIndices = rewriter.create( - loc, op->getResultTypes()[1], sortOpResult->getResult(1), dim, start, - /*end=*/op.getK(), step); + auto sortOpResult = AtenSortOp::create(rewriter, loc, self.getType(), + sortIndicesType, self, dim, + /*descending=*/op.getLargest()); + Value start = Torch::ConstantIntOp::create(rewriter, loc, + rewriter.getI64IntegerAttr(0)); + Value step = Torch::ConstantIntOp::create(rewriter, loc, + rewriter.getI64IntegerAttr(1)); + Value resultValue = + AtenSliceTensorOp::create(rewriter, loc, op->getResultTypes()[0], + sortOpResult->getResult(0), dim, start, + /*end=*/op.getK(), step); + Value resultIndices = + AtenSliceTensorOp::create(rewriter, loc, op->getResultTypes()[1], + sortOpResult->getResult(1), dim, 
start, + /*end=*/op.getK(), step); rewriter.replaceOp(op, {resultValue, resultIndices}); return success(); } @@ -11242,8 +11328,8 @@ class DecomposeAtenArgsortOp : public OpRewritePattern { auto sortIndicesType = selfType.getWithSizesAndDtype( selfType.getOptionalSizes(), IntegerType::get(context, 64, IntegerType::Signed)); - auto sortOpResult = rewriter.create( - loc, self.getType(), sortIndicesType, self, dim, descending); + auto sortOpResult = AtenSortOp::create( + rewriter, loc, self.getType(), sortIndicesType, self, dim, descending); rewriter.replaceOp(op, sortOpResult->getResult(1)); return success(); } @@ -11272,8 +11358,8 @@ Value getDFTMatmulCoeff(PatternRewriter &rewriter, Location loc, } } - return rewriter.create( - loc, matrixType, + return ValueTensorLiteralOp::create( + rewriter, loc, matrixType, DenseElementsAttr::get(matrixType.toBuiltinTensor(), ArrayRef(values))); } @@ -11330,12 +11416,12 @@ class DecomposeAtenFftRfftOp final : public OpRewritePattern { if (failed(getTransposedType(cast(input.getType()), dimA, dimB, transposedType))) return failure(); - Value cstDimA = - rewriter.create(loc, rewriter.getI64IntegerAttr(dimA)); - Value cstDimB = - rewriter.create(loc, rewriter.getI64IntegerAttr(dimB)); - transposed = rewriter.create(loc, transposedType, - input, cstDimA, cstDimB); + Value cstDimA = ConstantIntOp::create(rewriter, loc, + rewriter.getI64IntegerAttr(dimA)); + Value cstDimB = ConstantIntOp::create(rewriter, loc, + rewriter.getI64IntegerAttr(dimB)); + transposed = AtenTransposeIntOp::create(rewriter, loc, transposedType, + input, cstDimA, cstDimB); return success(); }; @@ -11363,7 +11449,7 @@ class DecomposeAtenFftRfftOp final : public OpRewritePattern { ValueTensorType matmulType = ValueTensorType::get(op.getContext(), matmulShape, dtype); Value flatRes = - rewriter.create(loc, matmulType, self, coeffMatrix); + AtenMatmulOp::create(rewriter, loc, matmulType, self, coeffMatrix); // Y = unflatten(X, -1, [outputFftDim, 2]) // : (D_0 x ... 
x D_m x outputFftDim x 2) @@ -11375,14 +11461,15 @@ class DecomposeAtenFftRfftOp final : public OpRewritePattern { Type unflattenedResType = ValueTensorType::get(op.getContext(), unflattenedResShape, dtype); Value cstMinusOne = - rewriter.create(loc, rewriter.getI64IntegerAttr(-1)); + ConstantIntOp::create(rewriter, loc, rewriter.getI64IntegerAttr(-1)); Value unflattenSizes = toIntListConstruct(rewriter, loc, {outputFftDim, 2}); - Value unflattenedRes = rewriter.create( - loc, unflattenedResType, flatRes, /*dim=*/cstMinusOne, unflattenSizes); + Value unflattenedRes = + AtenUnflattenIntOp::create(rewriter, loc, unflattenedResType, flatRes, + /*dim=*/cstMinusOne, unflattenSizes); Type complexResType = ValueTensorType::get(op.getContext(), complexResShape, ComplexType::get(dtype)); - Value complexRes = rewriter.create(loc, complexResType, - unflattenedRes); + Value complexRes = AtenViewAsComplexOp::create( + rewriter, loc, complexResType, unflattenedRes); // Transpose back if (needTranspose) { @@ -11431,32 +11518,32 @@ class DecomposeAtenHannWindowPeriodicOp if (window_length == 1) { Value one = - rewriter.create(loc, rewriter.getI64IntegerAttr(1)); + ConstantIntOp::create(rewriter, loc, rewriter.getI64IntegerAttr(1)); SmallVector sizes({one}); - Value sizeList = rewriter.create( - loc, ListType::get(IntType::get(context)), sizes); + Value sizeList = PrimListConstructOp::create( + rewriter, loc, ListType::get(IntType::get(context)), sizes); rewriter.replaceOpWithNewOp(op, opType, sizeList, opDtype, opLayout, opDevice, opPinMemory); return success(); } Value zero = - rewriter.create(loc, rewriter.getF64FloatAttr(0.0)); + ConstantFloatOp::create(rewriter, loc, rewriter.getF64FloatAttr(0.0)); - Value arange = rewriter.create( - loc, opType, zero, op.getWindowLength(), opDtype, opLayout, opDevice, - opPinMemory); + Value arange = AtenArangeStartOp::create(rewriter, loc, opType, zero, + op.getWindowLength(), opDtype, + opLayout, opDevice, opPinMemory); double denominator = 
!periodic ? window_length - 1 : window_length; double piOverDenominator = 3.14159 / denominator; - Value cstFactor = rewriter.create( - loc, rewriter.getF64FloatAttr(piOverDenominator)); + Value cstFactor = ConstantFloatOp::create( + rewriter, loc, rewriter.getF64FloatAttr(piOverDenominator)); Value fraction = - rewriter.create(loc, opType, arange, cstFactor); - Value sine = rewriter.create(loc, opType, fraction); + AtenMulScalarOp::create(rewriter, loc, opType, arange, cstFactor); + Value sine = AtenSinOp::create(rewriter, loc, opType, fraction); rewriter.replaceOpWithNewOp(op, opType, sine); @@ -11487,11 +11574,11 @@ class DecomposeAtenScatterValueOp SmallVector sizes; for (int64_t i = 0; i < indexRank; ++i) { Value dim = - rewriter.create(loc, rewriter.getI64IntegerAttr(i)); - sizes.push_back(rewriter.create(loc, index, /*dim=*/dim)); + ConstantIntOp::create(rewriter, loc, rewriter.getI64IntegerAttr(i)); + sizes.push_back(AtenSizeIntOp::create(rewriter, loc, index, /*dim=*/dim)); } - Value sizeList = rewriter.create( - loc, ListType::get(IntType::get(context)), sizes); + Value sizeList = PrimListConstructOp::create( + rewriter, loc, ListType::get(IntType::get(context)), sizes); auto selfType = cast(self.getType()); auto indexType = cast(index.getType()); @@ -11513,7 +11600,8 @@ class DecomposePrimsSumOp : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(PrimsSumOp op, PatternRewriter &rewriter) const override { - Value cstFalse = rewriter.create(op.getLoc(), false); + Value cstFalse = + Torch::ConstantBoolOp::create(rewriter, op.getLoc(), false); rewriter.replaceOpWithNewOp( op, op.getType(), op.getInp(), op.getDims(), /*keepdim=*/cstFalse, @@ -11543,19 +11631,19 @@ class DecomposeAtenSgnOp : public OpRewritePattern { } auto zero = - rewriter.create(loc, rewriter.getI64IntegerAttr(0)); + ConstantIntOp::create(rewriter, loc, rewriter.getI64IntegerAttr(0)); auto one = - rewriter.create(loc, 
rewriter.getI64IntegerAttr(1)); + ConstantIntOp::create(rewriter, loc, rewriter.getI64IntegerAttr(1)); auto minusOne = - rewriter.create(loc, rewriter.getI64IntegerAttr(-1)); + ConstantIntOp::create(rewriter, loc, rewriter.getI64IntegerAttr(-1)); auto compTy = outType.getWithSizesAndDtype(outType.getOptionalSizes(), rewriter.getI1Type()); auto greater = - rewriter.create(loc, compTy, op.getSelf(), zero); + AtenGtScalarOp::create(rewriter, loc, compTy, op.getSelf(), zero); auto less = - rewriter.create(loc, compTy, op.getSelf(), zero); + AtenLtScalarOp::create(rewriter, loc, compTy, op.getSelf(), zero); // Pseudo code: // if (in > 0) @@ -11568,7 +11656,7 @@ class DecomposeAtenSgnOp : public OpRewritePattern { // return 1 if inf // return -1 if -inf auto selectGreater = - rewriter.create(loc, outType, greater, one, zero); + AtenWhereScalarOp::create(rewriter, loc, outType, greater, one, zero); rewriter.replaceOpWithNewOp(op, outType, less, minusOne, selectGreater); @@ -11609,35 +11697,35 @@ class DecomposeAtenHeaviside : public OpRewritePattern { op.getContext(), llvm::ArrayRef(broadcastShape), resultTy.getDtype()); auto boolBroadcastType = ValueTensorType::get( op.getContext(), llvm::ArrayRef(broadcastShape), rewriter.getI1Type()); - Value indexBroadcastShapeTorchList = rewriter.create( - loc, Torch::ListType::get(Torch::IntType::get(op.getContext())), + Value indexBroadcastShapeTorchList = PrimListConstructOp::create( + rewriter, loc, + Torch::ListType::get(Torch::IntType::get(op.getContext())), broadcastShapeValue); - auto inputBroadcasted = rewriter.create( - loc, broadcastType, input, indexBroadcastShapeTorchList); - auto valueBroadcasted = rewriter.create( - loc, broadcastType, value, indexBroadcastShapeTorchList); + auto inputBroadcasted = AtenBroadcastToOp::create( + rewriter, loc, broadcastType, input, indexBroadcastShapeTorchList); + auto valueBroadcasted = AtenBroadcastToOp::create( + rewriter, loc, broadcastType, value, indexBroadcastShapeTorchList); 
Value zero = getConstantWithGivenDtypeAndValue(rewriter, loc, 0, resultTy.getDtype()); Value one = getConstantWithGivenDtypeAndValue(rewriter, loc, 1, resultTy.getDtype()); // Compute mask: input == 0 - auto inputEqZero = rewriter - .create(loc, boolBroadcastType, - inputBroadcasted, zero) + auto inputEqZero = AtenEqScalarOp::create(rewriter, loc, boolBroadcastType, + inputBroadcasted, zero) ->getResult(0); // Compute mask: input < 0 - auto inputLtZero = rewriter.create(loc, boolBroadcastType, - inputBroadcasted, zero); + auto inputLtZero = AtenLtScalarOp::create(rewriter, loc, boolBroadcastType, + inputBroadcasted, zero); // Compute mask: isnan(input) auto isNan = - rewriter.create(loc, boolBroadcastType, inputBroadcasted); + AtenIsnanOp::create(rewriter, loc, boolBroadcastType, inputBroadcasted); // Combine: input < 0 || isnan(input) - auto inputNegativeOrNan = rewriter.create( - loc, boolBroadcastType, inputLtZero, isNan); + auto inputNegativeOrNan = AtenLogicalOrOp::create( + rewriter, loc, boolBroadcastType, inputLtZero, isNan); // Select 0 if input < 0 or input is nan, else 1 - auto zerosOrOnes = rewriter.create( - loc, resultTy, inputNegativeOrNan, zero, one); + auto zerosOrOnes = AtenWhereScalarOp::create(rewriter, loc, resultTy, + inputNegativeOrNan, zero, one); // Final result: if input == 0, take from valueBroadcasted, else take from // zerosOrOnes rewriter.replaceOpWithNewOp(op, resultTy, inputEqZero, @@ -11659,10 +11747,10 @@ class DecomposeAtenTypeAsOp : public OpRewritePattern { auto other = op.getOther(); Location loc = op.getLoc(); - Value targetDtype = rewriter.create(loc, other); - Value nonBlocking = rewriter.create(loc, false); - Value copy = rewriter.create(loc, false); - Value memoryFormat = rewriter.create(loc); + Value targetDtype = Torch::PrimDtypeOp::create(rewriter, loc, other); + Value nonBlocking = Torch::ConstantBoolOp::create(rewriter, loc, false); + Value copy = Torch::ConstantBoolOp::create(rewriter, loc, false); + Value 
memoryFormat = Torch::ConstantNoneOp::create(rewriter, loc); rewriter.replaceOpWithNewOp( op, op.getType(), input, targetDtype, nonBlocking, copy, memoryFormat); return success(); @@ -11679,8 +11767,8 @@ static FailureOr unsqueezeTensorAtTrailingDim(Operation *op, PatternRewriter &rewriter, Value input, int count) { Location loc = op->getLoc(); - Value constMinusOne = rewriter.create( - loc, rewriter.getI64IntegerAttr(-1)); + Value constMinusOne = Torch::ConstantIntOp::create( + rewriter, loc, rewriter.getI64IntegerAttr(-1)); Value result = input; while (count--) { auto unsqzTensorInfo = @@ -11699,19 +11787,19 @@ static Value createIndexToReplaceNone(Operation *op, PatternRewriter &rewriter, int64_t dimSize) { Location loc = op->getLoc(); MLIRContext *context = op->getContext(); - Value none = rewriter.create(loc); + Value none = Torch::ConstantNoneOp::create(rewriter, loc); auto int64Dtype = getDtypeIntValueForType( rewriter, loc, rewriter.getIntegerType(/*width=*/64, /*isSigned=*/true)); auto resultType = ValueTensorType::get( context, {dimSize}, rewriter.getIntegerType(/*width=*/64, /*isSigned=*/true)); - auto dim = rewriter.create( - loc, rewriter.getI64IntegerAttr(dimInt)); - auto end = rewriter.create(loc, input, dim); - auto v = rewriter.create( - loc, resultType, /*end=*/end, /*dtype=*/int64Dtype, /*layout=*/none, - /*device=*/none, /*pin_memory=*/none); + auto dim = Torch::ConstantIntOp::create(rewriter, loc, + rewriter.getI64IntegerAttr(dimInt)); + auto end = Torch::AtenSizeIntOp::create(rewriter, loc, input, dim); + auto v = Torch::AtenArangeOp::create(rewriter, loc, resultType, /*end=*/end, + /*dtype=*/int64Dtype, /*layout=*/none, + /*device=*/none, /*pin_memory=*/none); return v; } @@ -11789,8 +11877,8 @@ static FailureOr createNewIndices(Operation *op, } auto listElemType = ValueTensorType::get(context, std::nullopt, nullptr); - Value newIndexList = rewriter.create( - loc, Torch::ListType::get(listElemType), listElements); + Value newIndexList = 
Torch::PrimListConstructOp::create( + rewriter, loc, Torch::ListType::get(listElemType), listElements); return newIndexList; } @@ -11829,8 +11917,8 @@ class DecomposeAtenIndexTensorOp : public OpRewritePattern { // By default, we regard the first index type as the list element type. auto indexElemType = cast(indices[0].getType()) .getWithSizesAndDtype(std::nullopt, nullptr); - auto newIndices = rewriter.create( - loc, Torch::ListType::get(indexElemType), indices); + auto newIndices = PrimListConstructOp::create( + rewriter, loc, Torch::ListType::get(indexElemType), indices); rewriter.replaceOpWithNewOp( op, op.getType(), input, newIndices); return success(); @@ -11861,23 +11949,24 @@ class DecomposeAtenIndexTensorOp : public OpRewritePattern { for (int i = 0; i < inputRank; i++) { if (indexUsed[i]) { newToOldDimMap.emplace_back(i); - dimValues.emplace_back(rewriter.create( - loc, rewriter.getI64IntegerAttr(i))); + dimValues.emplace_back(Torch::ConstantIntOp::create( + rewriter, loc, rewriter.getI64IntegerAttr(i))); permutedSizes.emplace_back(inputSizes[i]); } } for (int i = 0; i < inputRank; i++) { if (!indexUsed[i]) { newToOldDimMap.emplace_back(i); - dimValues.emplace_back(rewriter.create( - loc, rewriter.getI64IntegerAttr(i))); + dimValues.emplace_back(Torch::ConstantIntOp::create( + rewriter, loc, rewriter.getI64IntegerAttr(i))); permutedSizes.emplace_back(inputSizes[i]); } } - auto dimValueList = rewriter.create( - loc, Torch::ListType::get(Torch::IntType::get(context)), dimValues); - newInput = rewriter.create( - loc, + auto dimValueList = Torch::PrimListConstructOp::create( + rewriter, loc, Torch::ListType::get(Torch::IntType::get(context)), + dimValues); + newInput = Torch::AtenPermuteOp::create( + rewriter, loc, inputType.getWithSizesAndDtype(permutedSizes, inputType.getOptionalDtype()), input, dimValueList); @@ -11934,8 +12023,8 @@ class DecomposeAtenIndexPutLikeOp // By default, we regard the first index type as the list element type. 
auto indexElemType = cast(indices[0].getType()) .getWithSizesAndDtype(std::nullopt, nullptr); - auto newIndex = rewriter.create( - loc, Torch::ListType::get(indexElemType), indices); + auto newIndex = PrimListConstructOp::create( + rewriter, loc, Torch::ListType::get(indexElemType), indices); rewriter.replaceOpWithNewOp( op, op.getType(), input, newIndex, op.getValues(), op.getAccumulate()); @@ -12003,13 +12092,13 @@ class DecomposeAtenTileOp : public OpRewritePattern { } auto inputRank = inputType.getSizes().size(); if (dimsSize < inputRank) { - auto constantOne = rewriter.create( - op.getLoc(), rewriter.getI64IntegerAttr(1)); + auto constantOne = Torch::ConstantIntOp::create( + rewriter, op.getLoc(), rewriter.getI64IntegerAttr(1)); for (auto i = dimsSize; i < inputRank; ++i) { dimsElements.insert(dimsElements.begin(), constantOne); } - repeats = rewriter.create( - op.getLoc(), + repeats = Torch::PrimListConstructOp::create( + rewriter, op.getLoc(), Torch::ListType::get(Torch::IntType::get(op.getContext())), dimsElements); } @@ -12033,8 +12122,9 @@ class DecomposeAtenReshapeAsOp : public OpRewritePattern { Value input = op.getSelf(); Value other = op.getOther(); - auto otherShape = rewriter.create( - loc, Torch::ListType::get(Torch::IntType::get(context)), other); + auto otherShape = Torch::AtenSizeOp::create( + rewriter, loc, Torch::ListType::get(Torch::IntType::get(context)), + other); rewriter.replaceOpWithNewOp(op, op.getType(), input, otherShape); return success(); @@ -12063,7 +12153,7 @@ class DecomposeAtenLinalgNormOp : public OpRewritePattern { // default ord value is 2 for vector_norm auto ord = op.getOrd(); if (isa(ord.getType())) { - ord = rewriter.create(loc, rewriter.getI64IntegerAttr(2)); + ord = ConstantIntOp::create(rewriter, loc, rewriter.getI64IntegerAttr(2)); } rewriter.replaceOpWithNewOp( op, op.getType(), op.getSelf(), ord, op.getDim(), op.getKeepdim(), @@ -12083,43 +12173,43 @@ class DecomposeAtenFakeQuantizePerTensorAffineOp Location loc = 
op.getLoc(); MLIRContext *context = getContext(); - Value none = rewriter.create(loc); - Value falseVal = rewriter.create(loc, false); + Value none = ConstantNoneOp::create(rewriter, loc); + Value falseVal = ConstantBoolOp::create(rewriter, loc, false); Value one = - rewriter.create(loc, rewriter.getI64IntegerAttr(1)); + ConstantIntOp::create(rewriter, loc, rewriter.getI64IntegerAttr(1)); // input/scale - Value divScale = rewriter.create( - loc, op.getType(), op.getSelf(), op.getScale()); + Value divScale = AtenDivScalarOp::create(rewriter, loc, op.getType(), + op.getSelf(), op.getScale()); // std::nearby_int(input/scale) - Value round = rewriter.create(loc, op.getType(), divScale); + Value round = AtenRoundOp::create(rewriter, loc, op.getType(), divScale); // std::nearby_int(input/scale) + zero_point - Value addZeroPoint = rewriter.create( - loc, op.getType(), round, op.getZeroPoint(), one); + Value addZeroPoint = AtenAddScalarOp::create(rewriter, loc, op.getType(), + round, op.getZeroPoint(), one); // max(quant_min, std::nearby_int(input/scale) + zero_point) auto si64Type = IntegerType::get(context, 64, IntegerType::Signed); auto tensorIntType = ValueTensorType::get(context, ArrayRef{1}, si64Type); - Value max = rewriter.create( - loc, op.getType(), addZeroPoint, - rewriter.create(loc, tensorIntType, op.getQuantMin(), - /*dtype=*/none, - /*device=*/none, - /*requires_grad=*/falseVal)); + Value max = AtenMaximumOp::create( + rewriter, loc, op.getType(), addZeroPoint, + AtenTensorIntOp::create(rewriter, loc, tensorIntType, op.getQuantMin(), + /*dtype=*/none, + /*device=*/none, + /*requires_grad=*/falseVal)); // min(quant_max, max(quant_min, std::nearby_int(input/scale) + zero_point)) - Value min = rewriter.create( - loc, op.getType(), max, - rewriter.create(loc, tensorIntType, op.getQuantMax(), - /*dtype=*/none, /*device=*/none, - /*requires_grad=*/falseVal)); + Value min = AtenMinimumOp::create( + rewriter, loc, op.getType(), max, + 
AtenTensorIntOp::create(rewriter, loc, tensorIntType, op.getQuantMax(), + /*dtype=*/none, /*device=*/none, + /*requires_grad=*/falseVal)); // min(quant_max, max(quant_min, std::nearby_int(input/scale) + zero_point)) // - zero_point - Value subZeroPoint = rewriter.create( - loc, op.getType(), min, op.getZeroPoint(), one); + Value subZeroPoint = AtenSubScalarOp::create(rewriter, loc, op.getType(), + min, op.getZeroPoint(), one); // (min(quant_max, max(quant_min, std::nearby_int(input/scale) + // zero_point)) - zero_point) * scale - Value result = rewriter.create( - loc, op.getType(), subZeroPoint, op.getScale()); + Value result = AtenMulScalarOp::create(rewriter, loc, op.getType(), + subZeroPoint, op.getScale()); rewriter.replaceOp(op, result); return success(); } @@ -12140,9 +12230,9 @@ class DecomposeAtenFakeQuantizePerTensorAffineCachemaskOp if (!op->getResult(1).use_empty()) return failure(); - auto newOp = rewriter.create( - op.getLoc(), op->getResult(0).getType(), op.getSelf(), op.getScale(), - op.getZeroPoint(), op.getQuantMin(), op.getQuantMax()); + auto newOp = AtenFakeQuantizePerTensorAffineOp::create( + rewriter, op.getLoc(), op->getResult(0).getType(), op.getSelf(), + op.getScale(), op.getZeroPoint(), op.getQuantMin(), op.getQuantMax()); rewriter.replaceAllUsesWith(op->getResult(0), newOp); rewriter.eraseOp(op); @@ -12168,11 +12258,9 @@ class DecomposeAten_FakeQuantizePerTensorAffineCachemaskTensorQparamsOp if (!op->getResult(1).use_empty()) return failure(); - auto newOp = - rewriter.create( - op.getLoc(), op->getResult(0).getType(), op.getSelf(), - op.getScale(), op.getZeroPoint(), op.getQuantMin(), - op.getQuantMax()); + auto newOp = AtenFakeQuantizePerTensorAffineTensorQparamsOp::create( + rewriter, op.getLoc(), op->getResult(0).getType(), op.getSelf(), + op.getScale(), op.getZeroPoint(), op.getQuantMin(), op.getQuantMax()); rewriter.replaceAllUsesWith(op->getResult(0), newOp); rewriter.eraseOp(op); @@ -12195,9 +12283,10 @@ class 
DecomposeAtenFakeQuantizePerChannelAffineCachemaskOp if (!op->getResult(1).use_empty()) return failure(); - auto newOp = rewriter.create( - op.getLoc(), op->getResult(0).getType(), op.getSelf(), op.getScale(), - op.getZeroPoint(), op.getAxis(), op.getQuantMin(), op.getQuantMax()); + auto newOp = AtenFakeQuantizePerChannelAffineOp::create( + rewriter, op.getLoc(), op->getResult(0).getType(), op.getSelf(), + op.getScale(), op.getZeroPoint(), op.getAxis(), op.getQuantMin(), + op.getQuantMax()); rewriter.replaceAllUsesWith(op->getResult(0), newOp); rewriter.eraseOp(op); @@ -12225,14 +12314,15 @@ class DecomposeAtenFMaxMinOp : public OpRewritePattern { Value other = op.getOther(); Value normalResult = - rewriter.create(loc, outType, self, other).getResult(); + AtenOpT::create(rewriter, loc, outType, self, other).getResult(); Value selfIsNan = - rewriter.create(loc, nanMaskType, self).getResult(); + Torch::AtenIsnanOp::create(rewriter, loc, nanMaskType, self) + .getResult(); Value otherIsNan = - rewriter.create(loc, nanMaskType, other) + Torch::AtenIsnanOp::create(rewriter, loc, nanMaskType, other) .getResult(); - normalResult = rewriter.create( - loc, outType, otherIsNan, self, normalResult); + normalResult = Torch::AtenWhereSelfOp::create( + rewriter, loc, outType, otherIsNan, self, normalResult); rewriter.replaceOpWithNewOp(op, outType, selfIsNan, other, normalResult); @@ -12258,11 +12348,11 @@ class DecomposeAtenThresholdOp : public OpRewritePattern { Value threshold = op.getThreshold(); Value value = op.getValue(); - auto comOp = rewriter.create( - loc, - selfType.getWithSizesAndDtype(selfType.getSizes(), - rewriter.getI1Type()), - self, threshold); + auto comOp = + AtenGtScalarOp::create(rewriter, loc, + selfType.getWithSizesAndDtype( + selfType.getSizes(), rewriter.getI1Type()), + self, threshold); rewriter.replaceOpWithNewOp(op, op.getType(), comOp, self, value); @@ -12310,41 +12400,41 @@ class DecomposeTorchvisionNmsOp : public OpRewritePattern { Value scores = 
op.getScores(); Value iouThreshold = op.getIouThreshold(); - Value cst0 = rewriter.create( - loc, rewriter.getI64IntegerAttr(0)); - Value cst1 = rewriter.create( - loc, rewriter.getI64IntegerAttr(1)); - Value cst2 = rewriter.create( - loc, rewriter.getI64IntegerAttr(2)); - Value cst4 = rewriter.create( - loc, rewriter.getI64IntegerAttr(4)); - Value cstNone = rewriter.create(loc); - Value cstTrue = - rewriter.create(loc, rewriter.getBoolAttr(true)); - Value cstFalse = rewriter.create( - loc, rewriter.getBoolAttr(false)); + Value cst0 = Torch::ConstantIntOp::create(rewriter, loc, + rewriter.getI64IntegerAttr(0)); + Value cst1 = Torch::ConstantIntOp::create(rewriter, loc, + rewriter.getI64IntegerAttr(1)); + Value cst2 = Torch::ConstantIntOp::create(rewriter, loc, + rewriter.getI64IntegerAttr(2)); + Value cst4 = Torch::ConstantIntOp::create(rewriter, loc, + rewriter.getI64IntegerAttr(4)); + Value cstNone = Torch::ConstantNoneOp::create(rewriter, loc); + Value cstTrue = Torch::ConstantBoolOp::create(rewriter, loc, + rewriter.getBoolAttr(true)); + Value cstFalse = Torch::ConstantBoolOp::create(rewriter, loc, + rewriter.getBoolAttr(false)); // Get number of boxes for the loop count auto boxesTensorType = dyn_cast(boxes.getType()); auto dType = boxesTensorType.getDtype(); int64_t boxesSize = boxesTensorType.getSizes()[0]; - Value len = rewriter.create(loc, boxes, /*dim=*/cst0); + Value len = AtenSizeIntOp::create(rewriter, loc, boxes, /*dim=*/cst0); // Calculate the area of each box: (x2 - x1) * (y2 - y1) auto sliceTy = rewriter.getType( SmallVector{boxesSize, 2}, dType); - Value lowSlice = rewriter.create( - loc, sliceTy, boxes, - /*dim=*/cst1, /*start=*/cst0, /*end=*/cst2, /*step=*/cst1); - Value highSlice = rewriter.create( - loc, sliceTy, boxes, - /*dim=*/cst1, /*start=*/cst2, /*end=*/cst4, /*step=*/cst1); - Value distance = rewriter.create( - loc, sliceTy, highSlice, lowSlice, cst1); + Value lowSlice = AtenSliceTensorOp::create(rewriter, loc, sliceTy, boxes, + 
/*dim=*/cst1, /*start=*/cst0, + /*end=*/cst2, /*step=*/cst1); + Value highSlice = AtenSliceTensorOp::create(rewriter, loc, sliceTy, boxes, + /*dim=*/cst1, /*start=*/cst2, + /*end=*/cst4, /*step=*/cst1); + Value distance = Torch::AtenSubTensorOp::create(rewriter, loc, sliceTy, + highSlice, lowSlice, cst1); auto areaTy = rewriter.getType( SmallVector{boxesSize}, dType); - Value area = rewriter.create( - loc, areaTy, distance, /*dim=*/cst1, /*keepdim=*/cstFalse, + Value area = Torch::AtenProdDimIntOp::create( + rewriter, loc, areaTy, distance, /*dim=*/cst1, /*keepdim=*/cstFalse, /*dtype=*/cstNone); // Sort scores in descending order @@ -12353,27 +12443,30 @@ class DecomposeTorchvisionNmsOp : public OpRewritePattern { auto intTensorType = scoresType.getWithSizesAndDtype( scoresType.getOptionalSizes(), IntegerType::get(context, 64, IntegerType::Signed)); - auto sortResult = rewriter.create( - loc, TypeRange({scores.getType(), intTensorType}), scores, + auto sortResult = Torch::AtenSortOp::create( + rewriter, loc, TypeRange({scores.getType(), intTensorType}), scores, /*dim=*/cst0, /*descending=*/cstTrue); // Create a mask to mark if we keep the boxes - Value lenShapeList = rewriter.create( - loc, Torch::ListType::get(Torch::IntType::get(context)), + Value lenShapeList = Torch::PrimListConstructOp::create( + rewriter, loc, Torch::ListType::get(Torch::IntType::get(context)), SmallVector{len}); - Value mask = rewriter.create( - loc, intTensorType, lenShapeList, cstNone, cstNone, cstNone, cstNone); - Value zeroShapeList = rewriter.create( - loc, Torch::ListType::get(Torch::IntType::get(context)), + Value mask = + Torch::AtenOnesOp::create(rewriter, loc, intTensorType, lenShapeList, + cstNone, cstNone, cstNone, cstNone); + Value zeroShapeList = Torch::PrimListConstructOp::create( + rewriter, loc, Torch::ListType::get(Torch::IntType::get(context)), SmallVector{cst1}); auto zeroTy = rewriter.getType( SmallVector{1}, rewriter.getIntegerType(64, /*signed=*/true)); - Value 
falseMask = rewriter.create( - loc, zeroTy, zeroShapeList, cstNone, cstNone, cstNone, cstNone); + Value falseMask = + Torch::AtenZerosOp::create(rewriter, loc, zeroTy, zeroShapeList, + cstNone, cstNone, cstNone, cstNone); // Create an empty tensor for result - Value result = rewriter.create( - loc, intTensorType, lenShapeList, /*dtype=*/cst4, /*layout=*/cstNone, + Value result = Torch::AtenEmptyMemoryFormatOp::create( + rewriter, loc, intTensorType, lenShapeList, /*dtype=*/cst4, + /*layout=*/cstNone, /*device=*/cstNone, /*pinMemory=*/cstNone, /*memoryFormat=*/cstNone); auto intTy = rewriter.getType(); @@ -12383,21 +12476,21 @@ class DecomposeTorchvisionNmsOp : public OpRewritePattern { rewriter.getType(SmallVector{1, 2}, dType); auto extractTy = rewriter.getType( SmallVector{1}, rewriter.getIntegerType(64, true)); - Value float0 = rewriter.create( - loc, rewriter.getFloatAttr(dType, 0.0)); + Value float0 = Torch::ConstantFloatOp::create( + rewriter, loc, rewriter.getFloatAttr(dType, 0.0)); auto scalarFloatType = rewriter.getType( SmallVector{1}, dType); - Value float0Tensor = rewriter.create( - loc, scalarFloatType, float0); + Value float0Tensor = Torch::PrimNumToTensorScalarOp::create( + rewriter, loc, scalarFloatType, float0); // 1. Loop through the boxes based on sorted indices // 2. Add the current box to result if it's not suppressed // 3. Calculate the IoUs with all boxes // 4. Loop through the rest boxes in sorted indices // 5. 
Suppress the box if the corresponding IoU is larger than threshold - auto loop1 = rewriter.create( - loc, TypeRange({intTensorType, intTensorType, intTy}), len, cstTrue, - ValueRange({mask, result, cst0})); + auto loop1 = Torch::PrimLoopOp::create( + rewriter, loc, TypeRange({intTensorType, intTensorType, intTy}), len, + cstTrue, ValueRange({mask, result, cst0})); { PatternRewriter::InsertionGuard guard(rewriter); Block *loopBody1 = rewriter.createBlock( @@ -12410,69 +12503,71 @@ class DecomposeTorchvisionNmsOp : public OpRewritePattern { Value curCnt = loopBody1->getArgument(3); // Extract the mask to check if the base box is suppressed - Value extract = rewriter.create( - loc, extractTy, mask1, /*dim=*/cst0, /*index=*/i); - Value scalar = rewriter.create(loc, intTy, extract); - Value iskept = rewriter.create( - loc, rewriter.getType(), scalar); - auto ifFilterOthers = rewriter.create( - loc, TypeRange({intTensorType, intTensorType, intTy}), iskept); + Value extract = AtenSelectIntOp::create(rewriter, loc, extractTy, mask1, + /*dim=*/cst0, /*index=*/i); + Value scalar = Torch::AtenItemOp::create(rewriter, loc, intTy, extract); + Value iskept = Torch::AtenBoolIntOp::create( + rewriter, loc, rewriter.getType(), scalar); + auto ifFilterOthers = Torch::PrimIfOp::create( + rewriter, loc, TypeRange({intTensorType, intTensorType, intTy}), + iskept); { PatternRewriter::InsertionGuard guard(rewriter); rewriter.createBlock(&ifFilterOthers.getThenRegion(), ifFilterOthers.getThenRegion().begin()); // Scatter the selected indices into result - Value extractIdx1 = rewriter.create( - loc, extractTy, sortResult.getResults()[1], /*dim=*/cst0, + Value extractIdx1 = AtenSelectIntOp::create( + rewriter, loc, extractTy, sortResult.getResults()[1], /*dim=*/cst0, /*index=*/i); - Value next = rewriter.create(loc, curCnt, cst1); - Value updatedResult = rewriter.create( - loc, intTensorType, curResult, extractIdx1, /*dim=*/cst0, + Value next = Torch::AtenAddIntOp::create(rewriter, loc, 
curCnt, cst1); + Value updatedResult = Torch::AtenSliceScatterOp::create( + rewriter, loc, intTensorType, curResult, extractIdx1, /*dim=*/cst0, /*start=*/curCnt, /*end=*/next, /*step=*/cst1); // Get the coordinates of base box Value idx1 = - rewriter.create(loc, intTy, extractIdx1); - Value idx1End = rewriter.create(loc, idx1, cst1); - Value curBox = rewriter.create( - loc, rowSliceTy, boxes, + Torch::AtenItemOp::create(rewriter, loc, intTy, extractIdx1); + Value idx1End = Torch::AtenAddIntOp::create(rewriter, loc, idx1, cst1); + Value curBox = AtenSliceTensorOp::create( + rewriter, loc, rowSliceTy, boxes, /*dim=*/cst0, /*start=*/idx1, /*end=*/idx1End, /*step=*/cst1); // Calculate IoUs: intersectionArea / unionArea // Intersection area = intersectionWidth * intersectionHeight - Value point1 = rewriter.create( - loc, pointTy, curBox, - /*dim=*/cst1, /*start=*/cst0, /*end=*/cst2, /*step=*/cst1); - Value point2 = rewriter.create( - loc, pointTy, curBox, - /*dim=*/cst1, /*start=*/cst2, /*end=*/cst4, /*step=*/cst1); - Value innerLow = rewriter.create( - loc, sliceTy, lowSlice, point1); - Value innerHigh = rewriter.create( - loc, sliceTy, highSlice, point2); - Value innerDistance = rewriter.create( - loc, sliceTy, innerHigh, innerLow, cst1); - innerDistance = rewriter.create( - loc, sliceTy, innerDistance, float0Tensor); - Value intersectionArea = rewriter.create( - loc, areaTy, innerDistance, /*dim=*/cst1, /*keepdim=*/cstFalse, + Value point1 = AtenSliceTensorOp::create(rewriter, loc, pointTy, curBox, + /*dim=*/cst1, /*start=*/cst0, + /*end=*/cst2, /*step=*/cst1); + Value point2 = AtenSliceTensorOp::create(rewriter, loc, pointTy, curBox, + /*dim=*/cst1, /*start=*/cst2, + /*end=*/cst4, /*step=*/cst1); + Value innerLow = Torch::AtenMaximumOp::create(rewriter, loc, sliceTy, + lowSlice, point1); + Value innerHigh = Torch::AtenMinimumOp::create(rewriter, loc, sliceTy, + highSlice, point2); + Value innerDistance = Torch::AtenSubTensorOp::create( + rewriter, loc, sliceTy, 
innerHigh, innerLow, cst1); + innerDistance = Torch::AtenMaximumOp::create( + rewriter, loc, sliceTy, innerDistance, float0Tensor); + Value intersectionArea = Torch::AtenProdDimIntOp::create( + rewriter, loc, areaTy, innerDistance, /*dim=*/cst1, + /*keepdim=*/cstFalse, /*dtype=*/cstNone); - Value iEnd = rewriter.create(loc, i, cst1); - Value curArea = rewriter.create( - loc, scalarFloatType, area, + Value iEnd = Torch::AtenAddIntOp::create(rewriter, loc, i, cst1); + Value curArea = AtenSliceTensorOp::create( + rewriter, loc, scalarFloatType, area, /*dim=*/cst0, /*start=*/i, /*end=*/iEnd, /*step=*/cst1); // Union area = area1 + area2 - intersectionArea - Value unionArea = rewriter.create( - loc, areaTy, area, curArea, cst1); - unionArea = rewriter.create( - loc, areaTy, unionArea, intersectionArea, cst1); - Value iou = rewriter.create( - loc, areaTy, intersectionArea, unionArea); + Value unionArea = Torch::AtenAddTensorOp::create(rewriter, loc, areaTy, + area, curArea, cst1); + unionArea = Torch::AtenSubTensorOp::create( + rewriter, loc, areaTy, unionArea, intersectionArea, cst1); + Value iou = Torch::AtenDivTensorOp::create(rewriter, loc, areaTy, + intersectionArea, unionArea); // Loop through the rest of boxes in sorted indices - auto loop2 = rewriter.create(loc, intTensorType, len, - cstTrue, mask1); + auto loop2 = Torch::PrimLoopOp::create(rewriter, loc, intTensorType, + len, cstTrue, mask1); { PatternRewriter::InsertionGuard guard(rewriter); Block *loopBody2 = rewriter.createBlock( @@ -12482,79 +12577,81 @@ class DecomposeTorchvisionNmsOp : public OpRewritePattern { Value mask2 = loopBody2->getArgument(1); // Check if current index is out of range - j = rewriter.create(loc, j, i); - j = rewriter.create(loc, j, cst1); - Value isInRange = rewriter.create(loc, j, len); - auto ifCalculateIou = rewriter.create( - loc, TypeRange({intTensorType}), isInRange); + j = Torch::AtenAddIntOp::create(rewriter, loc, j, i); + j = Torch::AtenAddIntOp::create(rewriter, loc, j, 
cst1); + Value isInRange = Torch::AtenLtIntOp::create(rewriter, loc, j, len); + auto ifCalculateIou = Torch::PrimIfOp::create( + rewriter, loc, TypeRange({intTensorType}), isInRange); { PatternRewriter::InsertionGuard guard(rewriter); rewriter.createBlock(&ifCalculateIou.getThenRegion(), ifCalculateIou.getThenRegion().begin()); // Retrieve IoU and check if suppress the box - Value extractIdx2 = rewriter.create( - loc, extractTy, sortResult.getResults()[1], /*dim=*/cst0, + Value extractIdx2 = AtenSelectIntOp::create( + rewriter, loc, extractTy, sortResult.getResults()[1], + /*dim=*/cst0, /*index=*/j); Value idx2 = - rewriter.create(loc, intTy, extractIdx2); + Torch::AtenItemOp::create(rewriter, loc, intTy, extractIdx2); Value idx2End = - rewriter.create(loc, idx2, cst1); - Value curIoU = rewriter.create( - loc, scalarFloatType, iou, + Torch::AtenAddIntOp::create(rewriter, loc, idx2, cst1); + Value curIoU = AtenSliceTensorOp::create( + rewriter, loc, scalarFloatType, iou, /*dim=*/cst0, /*start=*/idx2, /*end=*/idx2End, /*step=*/cst1); - curIoU = rewriter.create( - loc, rewriter.getType(), curIoU); - Value isSuppressed = rewriter.create( - loc, curIoU, iouThreshold); + curIoU = Torch::AtenItemOp::create( + rewriter, loc, rewriter.getType(), curIoU); + Value isSuppressed = Torch::AtenGtFloatOp::create( + rewriter, loc, curIoU, iouThreshold); - auto ifUnmask = rewriter.create( - loc, TypeRange({intTensorType}), isSuppressed); + auto ifUnmask = Torch::PrimIfOp::create( + rewriter, loc, TypeRange({intTensorType}), isSuppressed); { PatternRewriter::InsertionGuard guard(rewriter); rewriter.createBlock(&ifUnmask.getThenRegion(), ifUnmask.getThenRegion().begin()); // Update the mask if suppress - Value jEnd = rewriter.create(loc, j, cst1); - Value updatedMask = rewriter.create( - loc, intTensorType, mask2, falseMask, /*dim=*/cst0, + Value jEnd = Torch::AtenAddIntOp::create(rewriter, loc, j, cst1); + Value updatedMask = Torch::AtenSliceScatterOp::create( + rewriter, loc, 
intTensorType, mask2, falseMask, /*dim=*/cst0, /*start=*/j, /*end=*/jEnd, /*step=*/cst1); - rewriter.create(loc, updatedMask); + Torch::PrimIfYieldOp::create(rewriter, loc, updatedMask); } { PatternRewriter::InsertionGuard guard(rewriter); rewriter.createBlock(&ifUnmask.getElseRegion(), ifUnmask.getElseRegion().begin()); - rewriter.create(loc, mask2); + Torch::PrimIfYieldOp::create(rewriter, loc, mask2); } - rewriter.create(loc, ifUnmask.getResult(0)); + Torch::PrimIfYieldOp::create(rewriter, loc, ifUnmask.getResult(0)); } { PatternRewriter::InsertionGuard guard(rewriter); rewriter.createBlock(&ifCalculateIou.getElseRegion(), ifCalculateIou.getElseRegion().begin()); - rewriter.create(loc, mask2); + Torch::PrimIfYieldOp::create(rewriter, loc, mask2); } - rewriter.create( - loc, cstTrue, ifCalculateIou.getResult(0)); + Torch::PrimLoopConditionOp::create(rewriter, loc, cstTrue, + ifCalculateIou.getResult(0)); } - rewriter.create( - loc, ValueRange({loop2.getResult(0), updatedResult, next})); + Torch::PrimIfYieldOp::create( + rewriter, loc, + ValueRange({loop2.getResult(0), updatedResult, next})); } { PatternRewriter::InsertionGuard guard(rewriter); rewriter.createBlock(&ifFilterOthers.getElseRegion(), ifFilterOthers.getElseRegion().begin()); - rewriter.create( - loc, ValueRange({mask1, curResult, curCnt})); + Torch::PrimIfYieldOp::create(rewriter, loc, + ValueRange({mask1, curResult, curCnt})); } - rewriter.create(loc, cstTrue, - ifFilterOthers.getResults()); + Torch::PrimLoopConditionOp::create(rewriter, loc, cstTrue, + ifFilterOthers.getResults()); } rewriter.replaceOpWithNewOp( @@ -12594,7 +12691,7 @@ class DecomposeAtenConstrainRangeForSizeOp if (isa(min.getType())) { // Set min value to 0 - min = rewriter.create(loc, 0); + min = Torch::ConstantIntOp::create(rewriter, loc, 0); } else { // Check if min value is a constant if (!matchPattern(min, m_TorchConstantInt(&minValue))) @@ -12635,9 +12732,9 @@ class DecomposeAten_AssertScalarOp auto assertCond = op.getSelf(); 
if (isa(assertCond.getType())) - assertCond = rewriter.create(loc, assertCond); + assertCond = AtenBoolIntOp::create(rewriter, loc, assertCond); else if (isa(assertCond.getType())) - assertCond = rewriter.create(loc, assertCond); + assertCond = AtenBoolFloatOp::create(rewriter, loc, assertCond); assert(isa(assertCond.getType()) && "Unhandled type encountered in aten._assert_scalar op"); @@ -12683,15 +12780,17 @@ class DecomposeAtenRoundDecimalsOp Value scale; if (decimals) { auto scaleVal = pow(10, decimals); - scale = rewriter.create( - loc, rewriter.getF64FloatAttr(scaleVal)); - newOp = rewriter.create(loc, op.getType(), input, scale); + scale = ConstantFloatOp::create(rewriter, loc, + rewriter.getF64FloatAttr(scaleVal)); + newOp = + AtenMulScalarOp::create(rewriter, loc, op.getType(), input, scale); } - newOp = rewriter.create(loc, op.getType(), newOp); + newOp = AtenRoundOp::create(rewriter, loc, op.getType(), newOp); if (decimals) { - newOp = rewriter.create(loc, op.getType(), newOp, scale); + newOp = + AtenDivScalarOp::create(rewriter, loc, op.getType(), newOp, scale); } rewriter.replaceOp(op, newOp); @@ -12725,20 +12824,21 @@ class DecomposeAtenBroadcastTensorsOp Type broadcastType = ValueTensorType::get( op.getContext(), llvm::ArrayRef(broadcastShape), dtype); - Value broadcastShapeTorchList = rewriter.create( - loc, Torch::ListType::get(Torch::IntType::get(op.getContext())), + Value broadcastShapeTorchList = PrimListConstructOp::create( + rewriter, loc, + Torch::ListType::get(Torch::IntType::get(op.getContext())), broadcastShapeValue); SmallVector broadcastedValues; for (int64_t i = 0; i < numTensors; i++) { auto inputTensor = tensors[i]; - auto broadcastedVal = rewriter.create( - loc, broadcastType, inputTensor, broadcastShapeTorchList); + auto broadcastedVal = AtenBroadcastToOp::create( + rewriter, loc, broadcastType, inputTensor, broadcastShapeTorchList); broadcastedValues.push_back(broadcastedVal); } - auto broadcastedValuesList = rewriter.create( - 
loc, Torch::ListType::get(broadcastType), broadcastedValues); + auto broadcastedValuesList = Torch::PrimListConstructOp::create( + rewriter, loc, Torch::ListType::get(broadcastType), broadcastedValues); rewriter.replaceOp(op, broadcastedValuesList); return success(); @@ -12811,7 +12911,7 @@ class DecomposeAtenAsStridedOp : public OpRewritePattern { int64_t resultRank = sizesInts.size(); Value cstZero = - rewriter.create(loc, rewriter.getI64IntegerAttr(0)); + ConstantIntOp::create(rewriter, loc, rewriter.getI64IntegerAttr(0)); if (inputRank > 1) { // If the input is not a 1-d tensor, we need to flatten it // to a 1D tensor before applying the strided indexing. @@ -12828,16 +12928,16 @@ class DecomposeAtenAsStridedOp : public OpRewritePattern { cast(inputType.getWithSizesAndDtype( {flattenedInputSize}, inputType.getOptionalDtype())); - Value end = rewriter.create( - loc, rewriter.getI64IntegerAttr(inputRank - 1)); - input = rewriter.create(loc, flattenedInputTy, - input, cstZero, end); + Value end = ConstantIntOp::create( + rewriter, loc, rewriter.getI64IntegerAttr(inputRank - 1)); + input = AtenFlattenUsingIntsOp::create(rewriter, loc, flattenedInputTy, + input, cstZero, end); } Value cstOne = - rewriter.create(loc, rewriter.getI64IntegerAttr(1)); + ConstantIntOp::create(rewriter, loc, rewriter.getI64IntegerAttr(1)); Value cstMinusOne = - rewriter.create(loc, rewriter.getI64IntegerAttr(-1)); + ConstantIntOp::create(rewriter, loc, rewriter.getI64IntegerAttr(-1)); SmallVector viewShapeInts(resultRank, 1); SmallVector viewShapeListElems(resultRank, cstOne); @@ -12846,31 +12946,32 @@ class DecomposeAtenAsStridedOp : public OpRewritePattern { Value finalIndices; for (unsigned dim = 0; dim < sizesInts.size(); dim++) { int64_t size = sizesInts[dim]; - Value cstNone = rewriter.create(loc); - Value end = - rewriter.create(loc, rewriter.getI64IntegerAttr(size)); + Value cstNone = ConstantNoneOp::create(rewriter, loc); + Value end = ConstantIntOp::create(rewriter, loc, + 
rewriter.getI64IntegerAttr(size)); auto arangeType = ValueTensorType::get(context, llvm::ArrayRef(size), si64Type); - Value index = rewriter.create( - loc, arangeType, end, cstNone, cstNone, cstNone, cstNone); + Value index = Torch::AtenArangeOp::create( + rewriter, loc, arangeType, end, cstNone, cstNone, cstNone, cstNone); // Set the current dimension to -1 for broadcasting viewShapeInts[dim] = -1; viewShapeListElems[dim] = cstMinusOne; - Value viewShapeList = rewriter.create( - loc, Torch::ListType::get(Torch::IntType::get(context)), + Value viewShapeList = Torch::PrimListConstructOp::create( + rewriter, loc, Torch::ListType::get(Torch::IntType::get(context)), viewShapeListElems); auto viewType = ValueTensorType::get( context, llvm::ArrayRef(viewShapeInts), si64Type); - index = rewriter.create(loc, viewType, index, viewShapeList); + index = AtenViewOp::create(rewriter, loc, viewType, index, viewShapeList); // Multiply the index with the stride for the current dimension - Value cstStride = rewriter.create( - loc, rewriter.getI64IntegerAttr(stridesInts[dim])); - index = rewriter.create(loc, viewType, index, cstStride); + Value cstStride = ConstantIntOp::create( + rewriter, loc, rewriter.getI64IntegerAttr(stridesInts[dim])); + index = + AtenMulScalarOp::create(rewriter, loc, viewType, index, cstStride); // Reset the current dimension to 1 for the next iteration viewShapeInts[dim] = 1; @@ -12889,8 +12990,8 @@ class DecomposeAtenAsStridedOp : public OpRewritePattern { Type broadcastType = ValueTensorType::get( context, llvm::ArrayRef(broadcastShape), si64Type); - finalIndices = rewriter.create( - loc, broadcastType, finalIndices, index, cstOne); + finalIndices = AtenAddTensorOp::create(rewriter, loc, broadcastType, + finalIndices, index, cstOne); } int64_t flattenedResultSize = 1; @@ -12898,44 +12999,45 @@ class DecomposeAtenAsStridedOp : public OpRewritePattern { flattenedResultSize *= size; // Flattening the indices and adding the storage offset - finalIndices = 
rewriter.create( - loc, + finalIndices = AtenFlattenUsingIntsOp::create( + rewriter, loc, ValueTensorType::get(context, llvm::ArrayRef(flattenedResultSize), si64Type), finalIndices, cstZero, cstMinusOne); // -1 means flatten all if (storageOffset != 0) { - Value cstStorageOffset = rewriter.create( - loc, rewriter.getI64IntegerAttr(storageOffset)); - finalIndices = rewriter.create( - loc, finalIndices.getType(), finalIndices, cstStorageOffset, cstOne); + Value cstStorageOffset = ConstantIntOp::create( + rewriter, loc, rewriter.getI64IntegerAttr(storageOffset)); + finalIndices = + AtenAddScalarOp::create(rewriter, loc, finalIndices.getType(), + finalIndices, cstStorageOffset, cstOne); } // Index the flattened input tensor Type listElemType = inputType.getWithSizesAndDtype(/*optionalSizes=*/std::nullopt, /*optionalDtype=*/nullptr); - Value indicesList = rewriter.create( - loc, Torch::ListType::get(listElemType), + Value indicesList = Torch::PrimListConstructOp::create( + rewriter, loc, Torch::ListType::get(listElemType), SmallVector{finalIndices}); auto flattenedResultTy = ValueTensorType::get(context, llvm::ArrayRef(flattenedResultSize), inputType.getOptionalDtype()); - Value result = rewriter.create(loc, flattenedResultTy, - input, indicesList); + Value result = AtenIndexTensorOp::create(rewriter, loc, flattenedResultTy, + input, indicesList); // Reshape the result to the desired output size SmallVector sizesIntsValues; for (int64_t size : sizesInts) { - sizesIntsValues.push_back(rewriter.create( - loc, rewriter.getI64IntegerAttr(size))); + sizesIntsValues.push_back(ConstantIntOp::create( + rewriter, loc, rewriter.getI64IntegerAttr(size))); } - Value resultSizeList = rewriter.create( - loc, Torch::ListType::get(Torch::IntType::get(context)), + Value resultSizeList = Torch::PrimListConstructOp::create( + rewriter, loc, Torch::ListType::get(Torch::IntType::get(context)), sizesIntsValues); result = - rewriter.create(loc, op.getType(), result, resultSizeList); + 
AtenViewOp::create(rewriter, loc, op.getType(), result, resultSizeList); rewriter.replaceOp(op, result); return success(); diff --git a/lib/Dialect/Torch/Transforms/FuseQuantizedOps.cpp b/lib/Dialect/Torch/Transforms/FuseQuantizedOps.cpp index da06e1c59a75..e418a4b08ec0 100644 --- a/lib/Dialect/Torch/Transforms/FuseQuantizedOps.cpp +++ b/lib/Dialect/Torch/Transforms/FuseQuantizedOps.cpp @@ -148,14 +148,14 @@ class QuantizeOperandsPastCommutingOps : public OpRewritePattern { Value quantPadValue; if (isa(floatPadValue.getType())) quantPadValue = - rewriter.create(loc, chain.zeroPoint); + AtenFloatScalarOp::create(rewriter, loc, chain.zeroPoint); else { floatPadValue = - rewriter.create(loc, floatPadValue); - quantPadValue = rewriter.create( - loc, floatPadValue, chain.scale); - quantPadValue = rewriter.create( - loc, quantPadValue, chain.zeroPoint); + AtenFloatScalarOp::create(rewriter, loc, floatPadValue); + quantPadValue = Torch::AtenDivFloatOp::create( + rewriter, loc, floatPadValue, chain.scale); + quantPadValue = Torch::AtenAddFloatIntOp::create( + rewriter, loc, quantPadValue, chain.zeroPoint); } // clamp pad value to qint range if (auto intType = dyn_cast(intDType)) { @@ -165,20 +165,21 @@ class QuantizeOperandsPastCommutingOps : public OpRewritePattern { "quantized int bitwidth should be less than 64"); int64_t minInt = isSigned ? -(1 << (width - 1)) : 0; int64_t maxInt = isSigned ? 
-minInt - 1 : ((1 << width) - 1); - Value minQValueFloat = rewriter.create( - loc, rewriter.getF64FloatAttr(minInt)); - Value maxQValueFloat = rewriter.create( - loc, rewriter.getF64FloatAttr(maxInt)); + Value minQValueFloat = ConstantFloatOp::create( + rewriter, loc, rewriter.getF64FloatAttr(minInt)); + Value maxQValueFloat = ConstantFloatOp::create( + rewriter, loc, rewriter.getF64FloatAttr(maxInt)); SmallVector emptyShape; auto floatTensorType = rewriter.getType( emptyShape, rewriter.getF64Type()); Value quantPadValueTensor = createRank0Tensor( rewriter, loc, floatTensorType, quantPadValue); - Value clampedTensor = rewriter.create( - loc, floatTensorType, quantPadValueTensor, minQValueFloat, - maxQValueFloat); - quantPadValue = rewriter.create( - loc, rewriter.getType(), clampedTensor); + Value clampedTensor = Torch::AtenClampOp::create( + rewriter, loc, floatTensorType, quantPadValueTensor, + minQValueFloat, maxQValueFloat); + quantPadValue = Torch::AtenItemOp::create( + rewriter, loc, rewriter.getType(), + clampedTensor); } // quantPadValue is a float, but will get converted/truncated currOperands.back() = quantPadValue; @@ -203,8 +204,9 @@ class QuantizeOperandsPastCommutingOps : public OpRewritePattern { cast(chain.dequantOpd.getType()).getOptionalDtype(); auto newMPTQTType = rewriter.getType( cast(operands[i].getType()).getSizes(), qTorchType); - operands[i] = rewriter.create( - loc, newMPTQTType, oldOpd, MPTQTOperands[1], MPTQTOperands[2]); + operands[i] = Aten_MakePerTensorQuantizedTensorOp::create( + rewriter, loc, newMPTQTType, oldOpd, MPTQTOperands[1], + MPTQTOperands[2]); } rewriter.replaceOpWithNewOp(op, op.getType(), operands); @@ -248,11 +250,11 @@ template class QuantizeBias : public OpRewritePattern { return failure(); } - Value biasScale = rewriter.create( - op.getLoc(), lhsScale.getType(), lhsScale, rhsScale); + Value biasScale = AtenMulFloatOp::create( + rewriter, op.getLoc(), lhsScale.getType(), lhsScale, rhsScale); - Value zero = 
rewriter.create( - op.getLoc(), rewriter.getType(), + Value zero = Torch::ConstantIntOp::create( + rewriter, op.getLoc(), rewriter.getType(), rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0)); auto qi32Ty = rewriter.getType(); @@ -261,10 +263,10 @@ template class QuantizeBias : public OpRewritePattern { auto newBiasTy = rewriter.getType(biasTy.getOptionalSizes(), qi32Ty); Value dtype = getDtypeIntValueForType(rewriter, op.getLoc(), qi32Ty); - bias = rewriter.create( - op.getLoc(), newBiasTy, bias, biasScale, zero, dtype); - bias = rewriter.create( - op.getLoc(), + bias = AtenQuantizePerTensorOp::create(rewriter, op.getLoc(), newBiasTy, + bias, biasScale, zero, dtype); + bias = AtenIntReprOp::create( + rewriter, op.getLoc(), rewriter.getType( biasTy.getOptionalSizes(), rewriter.getIntegerType(32, IntegerType::Signed)), @@ -275,12 +277,12 @@ template class QuantizeBias : public OpRewritePattern { auto convTy = rewriter.getType( resultTy.getOptionalSizes(), rewriter.getIntegerType(32, IntegerType::Signed)); - auto conv = rewriter.create(op.getLoc(), convTy, operands); + auto conv = SrcOp::create(rewriter, op.getLoc(), convTy, operands); auto convQTy = rewriter.getType(resultTy.getOptionalSizes(), qi32Ty); - auto makeOut = rewriter.create( - op.getLoc(), convQTy, conv, biasScale, zero); + auto makeOut = Aten_MakePerTensorQuantizedTensorOp::create( + rewriter, op.getLoc(), convQTy, conv, biasScale, zero); rewriter.replaceOpWithNewOp(op, op.getType(), makeOut); @@ -322,33 +324,34 @@ class QuantizeAccumulator : public OpRewritePattern { return failure(); // Quantize the bias input to the expected result: - Value zero = rewriter.create( - op.getLoc(), rewriter.getType(), + Value zero = Torch::ConstantIntOp::create( + rewriter, op.getLoc(), rewriter.getType(), rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0)); auto qi32Ty = rewriter.getType(); - Value biasScale = rewriter.create( - op.getLoc(), lhsScale.getType(), lhsScale, rhsScale); + Value biasScale = 
AtenMulFloatOp::create( + rewriter, op.getLoc(), lhsScale.getType(), lhsScale, rhsScale); // Update the quantied type: llvm::SmallVector operands(op.getOperands()); auto newResultTy = rewriter.getType(resultTy.getOptionalSizes(), qi32Ty); - auto conv = rewriter.create(op.getLoc(), newResultTy, operands); + auto conv = SrcOp::create(rewriter, op.getLoc(), newResultTy, operands); // Attach the quantize information to the resulting qint32: auto intReprTy = rewriter.getType( resultTy.getOptionalSizes(), rewriter.getIntegerType(32, IntegerType::Signed)); - auto intRepr = rewriter.create(op.getLoc(), intReprTy, conv); + auto intRepr = + AtenIntReprOp::create(rewriter, op.getLoc(), intReprTy, conv); auto quantTy = rewriter.getType(resultTy.getOptionalSizes(), qi32Ty); - auto quant = rewriter.create( - op.getLoc(), quantTy, intRepr, biasScale, zero); + auto quant = Aten_MakePerTensorQuantizedTensorOp::create( + rewriter, op.getLoc(), quantTy, intRepr, biasScale, zero); auto dequant = - rewriter.create(op.getLoc(), resultTy, quant); + AtenDequantizeTensorOp::create(rewriter, op.getLoc(), resultTy, quant); rewriter.replaceOp(op, dequant); return success(); @@ -397,23 +400,24 @@ class QuantizeResultLikeOperand : public OpRewritePattern { // set SrcOp type to use quantized dtype from input auto newResultTy = rewriter.getType(resultTy.getOptionalSizes(), qDtype); - auto newResult = rewriter.create(op.getLoc(), newResultTy, operands); + auto newResult = + SrcOp::create(rewriter, op.getLoc(), newResultTy, operands); // int repr to get non quantized int type result auto intReprTy = rewriter.getType( resultTy.getOptionalSizes(), intReprDtype); auto intRepr = - rewriter.create(op.getLoc(), intReprTy, newResult); + AtenIntReprOp::create(rewriter, op.getLoc(), intReprTy, newResult); // requantize so the scale and zero-point info can be attached auto quantTy = rewriter.getType(resultTy.getOptionalSizes(), qDtype); - auto quant = rewriter.create( - op.getLoc(), quantTy, intRepr, 
inputScale, inputZeroPoint); + auto quant = Aten_MakePerTensorQuantizedTensorOp::create( + rewriter, op.getLoc(), quantTy, intRepr, inputScale, inputZeroPoint); // dequant back to original dtype auto dequant = - rewriter.create(op.getLoc(), resultTy, quant); + AtenDequantizeTensorOp::create(rewriter, op.getLoc(), resultTy, quant); rewriter.replaceOp(op, dequant); return success(); } diff --git a/lib/Dialect/Torch/Transforms/GlobalizeObjectGraph.cpp b/lib/Dialect/Torch/Transforms/GlobalizeObjectGraph.cpp index a63b6f3196ee..d47220e348ea 100644 --- a/lib/Dialect/Torch/Transforms/GlobalizeObjectGraph.cpp +++ b/lib/Dialect/Torch/Transforms/GlobalizeObjectGraph.cpp @@ -172,9 +172,9 @@ class ObjectGraphInfo { } else if (usedSlots.find(slot) != usedSlots.end()) { // Only create the GlobalSlotOp if the slot is used at all. std::string linkageName = llvm::join(nameStack, "."); - auto globalSlot = globalSlotBuilder.create( - slot.getLoc(), linkageName, - /*sym_visibility=*/nullptr, attr.getType()); + auto globalSlot = + GlobalSlotOp::create(globalSlotBuilder, slot.getLoc(), linkageName, + /*sym_visibility=*/nullptr, attr.getType()); if (attr.getIsPrivate()) globalSlot.setVisibility(SymbolTable::Visibility::Private); assert(slotToGlobalSlot.find(slot) == slotToGlobalSlot.end()); @@ -230,7 +230,7 @@ createGlobalSlotModuleInitializer(ModuleOp module, SymbolTable &symbolTable, ObjectGraphInfo &objectGraphInfo) { auto builder = OpBuilder::atBlockBegin(module.getBody()); auto moduleInitializer = - builder.create(module.getLoc()); + GlobalSlotModuleInitializerOp::create(builder, module.getLoc()); Block *body = builder.createBlock(&moduleInitializer.getInitializer()); builder.setInsertionPointToEnd(body); SmallVector opsToMove; @@ -254,8 +254,8 @@ createGlobalSlotModuleInitializer(ModuleOp module, SymbolTable &symbolTable, slotSymNames.push_back(FlatSymbolRefAttr::get(symName)); initialValues.push_back(mapping.lookup(initializer)); } - builder.create( - moduleInitializer.getLoc(), + 
InitializeGlobalSlotsOp::create( + builder, moduleInitializer.getLoc(), ArrayAttr::get(module.getContext(), slotSymNames), initialValues); return success(); } @@ -504,8 +504,9 @@ static LogicalResult rewriteMonomorphizedFuncClone( if (slot.getName() == op.getName()) affectedSlot = slot; } - OpBuilder(op).create( - op.getLoc(), + OpBuilder builder(op); + GlobalSlotSetOp::create( + builder, op.getLoc(), objectGraphInfo.getGlobalSlotFor(affectedSlot).getSymName(), op.getValue()); toErase.push_back(op); @@ -520,8 +521,9 @@ static LogicalResult rewriteMonomorphizedFuncClone( if (slot.getName() == op.getName()) affectedSlot = slot; } - auto newOp = OpBuilder(op).create( - op.getLoc(), op.getType(), + OpBuilder builder(op); + auto newOp = GlobalSlotGetOp::create( + builder, op.getLoc(), op.getType(), objectGraphInfo.getGlobalSlotFor(affectedSlot).getSymName()); op.replaceAllUsesWith(&*newOp); } @@ -539,8 +541,9 @@ static LogicalResult rewriteMonomorphizedFuncClone( return !isa(v.getType()); })); assert(newFuncs.find(monomorphization) != newFuncs.end()); - auto newOp = OpBuilder(op).create( - op.getLoc(), newFuncs[monomorphization], newArguments); + OpBuilder builder(op); + auto newOp = func::CallOp::create(builder, op.getLoc(), + newFuncs[monomorphization], newArguments); op.replaceAllUsesWith(newOp); toErase.push_back(op); return WalkResult::advance(); diff --git a/lib/Dialect/Torch/Transforms/InlineGlobalSlots.cpp b/lib/Dialect/Torch/Transforms/InlineGlobalSlots.cpp index 4a15083ae083..12660cfee47c 100644 --- a/lib/Dialect/Torch/Transforms/InlineGlobalSlots.cpp +++ b/lib/Dialect/Torch/Transforms/InlineGlobalSlots.cpp @@ -407,8 +407,8 @@ class InlineGlobalSlotsPass } { OpBuilder builder(initialize); - builder.create( - initialize.getLoc(), + Torch::InitializeGlobalSlotsOp::create( + builder, initialize.getLoc(), ArrayAttr::get(module.getContext(), newSlotSymNames), newInitialValues); } diff --git a/lib/Dialect/Torch/Transforms/MatchQuantizedOps.cpp 
b/lib/Dialect/Torch/Transforms/MatchQuantizedOps.cpp index 0e3cda033a18..e5c279415c7f 100644 --- a/lib/Dialect/Torch/Transforms/MatchQuantizedOps.cpp +++ b/lib/Dialect/Torch/Transforms/MatchQuantizedOps.cpp @@ -45,13 +45,13 @@ class MatchQuantizeOperator : public OpRewritePattern { auto qTy = rewriter.getType(resultTy.getOptionalSizes(), qeTy); - Value quant = rewriter.create( - op.getLoc(), qTy, + Value quant = AtenQuantizePerTensorOp::create( + rewriter, op.getLoc(), qTy, /*self=*/op.getOperand(0), /*scale=*/op.getOperand(1), /*zero_point=*/op.getOperand(2), /*dtype=*/op.getOperand(5)); if (qTy != resultTy) { - quant = rewriter.create(op.getLoc(), resultTy, quant); + quant = AtenIntReprOp::create(rewriter, op.getLoc(), resultTy, quant); } rewriter.replaceOpWithNewOp( @@ -62,8 +62,8 @@ class MatchQuantizeOperator : public OpRewritePattern { auto prepareDequantize = [&](Value quantMin, Value quantMax, Value &clamp, Type &qTy) { clamp = - rewriter.create(op.getLoc(), op.getOperand(0).getType(), - op.getOperand(0), quantMin, quantMax); + AtenClampOp::create(rewriter, op.getLoc(), op.getOperand(0).getType(), + op.getOperand(0), quantMin, quantMax); auto clampTy = cast(clamp.getType()); if (!clampTy.hasDtype()) @@ -88,8 +88,9 @@ class MatchQuantizeOperator : public OpRewritePattern { qTy))) return failure(); - auto quant = rewriter.create( - op.getLoc(), qTy, clamp, op.getOperand(1), op.getOperand(2)); + auto quant = Aten_MakePerTensorQuantizedTensorOp::create( + rewriter, op.getLoc(), qTy, clamp, op.getOperand(1), + op.getOperand(2)); rewriter.replaceOpWithNewOp( op, op.getResultTypes(), quant); return success(); @@ -101,8 +102,8 @@ class MatchQuantizeOperator : public OpRewritePattern { if (failed(prepareDequantize(op.getOperand(4), op.getOperand(5), clamp, qTy))) return failure(); - auto quant = rewriter.create( - op.getLoc(), qTy, clamp, op.getOperand(1), op.getOperand(2), + auto quant = Aten_MakePerChannelQuantizedTensorOp::create( + rewriter, op.getLoc(), qTy, 
clamp, op.getOperand(1), op.getOperand(2), op.getOperand(3)); rewriter.replaceOpWithNewOp(op, op.getResultTypes(), quant); diff --git a/lib/Dialect/Torch/Transforms/RecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/RecomposeComplexOps.cpp index d954731fc4e0..1d7c926473c2 100644 --- a/lib/Dialect/Torch/Transforms/RecomposeComplexOps.cpp +++ b/lib/Dialect/Torch/Transforms/RecomposeComplexOps.cpp @@ -25,10 +25,10 @@ namespace { // a/b's type should be !torch.int Value getIntCeilDiv(PatternRewriter &rewriter, Location loc, Value a, Value b) { Value cstOne = - rewriter.create(loc, rewriter.getI64IntegerAttr(1)); - Value dividend = rewriter.create(loc, a, b); - dividend = rewriter.create(loc, dividend, cstOne); - Value result = rewriter.create(loc, dividend, b); + ConstantIntOp::create(rewriter, loc, rewriter.getI64IntegerAttr(1)); + Value dividend = AtenAddIntOp::create(rewriter, loc, a, b); + dividend = AtenSubIntOp::create(rewriter, loc, dividend, cstOne); + Value result = AtenFloordivIntOp::create(rewriter, loc, dividend, b); return result; } @@ -65,25 +65,25 @@ class RecomposeSliceCopy_ : public OpRewritePattern { Value newStart = sliceOp.getStart(); Value newEnd = sliceOp.getEnd(); - Value dimSize = rewriter.create( - op.getLoc(), sliceOp.getSelf(), sliceOp.getDim()); + Value dimSize = AtenSizeIntOp::create(rewriter, op.getLoc(), + sliceOp.getSelf(), sliceOp.getDim()); if (end < 0) { - newEnd = - rewriter.create(op.getLoc(), dimSize, sliceOp.getEnd()); + newEnd = AtenAddIntOp::create(rewriter, op.getLoc(), dimSize, + sliceOp.getEnd()); } - newStart = rewriter.create(op.getLoc(), newStart, dimSize); - newEnd = rewriter.create(op.getLoc(), newEnd, dimSize); + newStart = PrimMinIntOp::create(rewriter, op.getLoc(), newStart, dimSize); + newEnd = PrimMinIntOp::create(rewriter, op.getLoc(), newEnd, dimSize); - Value noneVal = rewriter.create(op.getLoc()); - Value falseVal = rewriter.create(op.getLoc(), false); + Value noneVal = ConstantNoneOp::create(rewriter, 
op.getLoc()); + Value falseVal = ConstantBoolOp::create(rewriter, op.getLoc(), false); // Create IndexPut_Op BaseTensorType tensorType = cast(op.getType()); Type rangeType = tensorType.getWithSizesAndDtype( {kUnknownSize}, tensorType.getOptionalDtype()); - Value range = rewriter.create( - op.getLoc(), rangeType, newStart, newEnd, sliceOp.getStep(), + Value range = AtenArangeStartStepOp::create( + rewriter, op.getLoc(), rangeType, newStart, newEnd, sliceOp.getStep(), /*dtype=*/noneVal, /*layout=*/noneVal, /*device=*/noneVal, /*pin_memory=*/noneVal); @@ -93,8 +93,8 @@ class RecomposeSliceCopy_ : public OpRewritePattern { indicesVector.push_back(range); Type indicesType = tensorType.getWithSizesAndDtype( /*optionalSizes=*/std::nullopt, /*optionalDtype=*/nullptr); - Value indices = rewriter.create( - op.getLoc(), + Value indices = PrimListConstructOp::create( + rewriter, op.getLoc(), Torch::ListType::get(op->getContext(), Torch::OptionalType::get(indicesType)), indicesVector); @@ -126,8 +126,8 @@ class RecomposeSelectFill_ : public OpRewritePattern { if (!matchPattern(selectOp.getDim(), m_TorchConstantInt(&dim))) return failure(); - Value noneVal = rewriter.create(op.getLoc()); - Value falseVal = rewriter.create(op.getLoc(), false); + Value noneVal = ConstantNoneOp::create(rewriter, op.getLoc()); + Value falseVal = ConstantBoolOp::create(rewriter, op.getLoc(), false); // Create IndexPut_Op // Convert indexNum to indexTensor for the selectOp @@ -137,16 +137,16 @@ class RecomposeSelectFill_ : public OpRewritePattern { selectOp.getIndex().getType()); Type emptyTensorType = selectOutTy.getWithSizesAndDtype(llvm::ArrayRef(empty), dtype); - Value indexTensor = rewriter.create( - selectOp.getLoc(), emptyTensorType, selectOp.getIndex()); + Value indexTensor = PrimNumToTensorScalarOp::create( + rewriter, selectOp.getLoc(), emptyTensorType, selectOp.getIndex()); // Create indicesVector for IndexPut_Op by TorchNone and indexTensor BaseTensorType tensorType = 
cast(op->getResultTypes()[0]); SmallVector indicesVector(dim, noneVal); indicesVector.push_back(indexTensor); - Value indices = rewriter.create( - op.getLoc(), + Value indices = PrimListConstructOp::create( + rewriter, op.getLoc(), Torch::ListType::get(op->getContext(), Torch::OptionalType::get(tensorType)), indicesVector); @@ -176,12 +176,13 @@ class RecomposeUnbindListUnpack : public OpRewritePattern { Value input = unbindOp.getSelf(); // add runtime.assert to check unbind's dim size == numResults - Value totalSize = rewriter.create(loc, input, dim); - Value cstNumResults = rewriter.create( - loc, rewriter.getI64IntegerAttr(op.getNumResults())); - Value eqOrNot = rewriter.create(loc, totalSize, cstNumResults); - rewriter.create( - loc, eqOrNot, + Value totalSize = AtenSizeIntOp::create(rewriter, loc, input, dim); + Value cstNumResults = ConstantIntOp::create( + rewriter, loc, rewriter.getI64IntegerAttr(op.getNumResults())); + Value eqOrNot = + AtenEqIntOp::create(rewriter, loc, totalSize, cstNumResults); + RuntimeAssertOp::create( + rewriter, loc, eqOrNot, rewriter.getStringAttr("unbind's dim size should equal to " "prim.list_unpack's num results")); @@ -189,10 +190,10 @@ class RecomposeUnbindListUnpack : public OpRewritePattern { for (size_t i = 0; i < op.getNumResults(); i++) { // rewrite to select.int op auto resultTy = op.getResult(i).getType(); - auto index = rewriter.create( - op->getLoc(), rewriter.getI64IntegerAttr(i)); - auto newSelect = rewriter.create(op->getLoc(), resultTy, - input, dim, index); + auto index = Torch::ConstantIntOp::create(rewriter, op->getLoc(), + rewriter.getI64IntegerAttr(i)); + auto newSelect = AtenSelectIntOp::create(rewriter, op->getLoc(), resultTy, + input, dim, index); slices.push_back(newSelect); } rewriter.replaceOp(op, slices); @@ -227,16 +228,16 @@ class RecomposeUnbindGetItem : public OpRewritePattern { Value input = unbind.getSelf(); // add runtime.assert to check: index - Value totalSize = rewriter.create(loc, input, 
dim); - Value ltOrNot = rewriter.create(loc, op.getIdx(), totalSize); - rewriter.create( - loc, ltOrNot, + Value totalSize = AtenSizeIntOp::create(rewriter, loc, input, dim); + Value ltOrNot = AtenLtIntOp::create(rewriter, loc, op.getIdx(), totalSize); + RuntimeAssertOp::create( + rewriter, loc, ltOrNot, rewriter.getStringAttr("index should less than unbind's dim size")); // rewrite to slice op auto resultTy = op.getResult().getType(); - Value newSelect = rewriter.create(loc, resultTy, input, - dim, op.getIdx()); + Value newSelect = AtenSelectIntOp::create(rewriter, loc, resultTy, input, + dim, op.getIdx()); rewriter.replaceOp(op, newSelect); if (unbind.getResult().use_empty()) rewriter.eraseOp(unbind); @@ -277,23 +278,24 @@ class RecomposeSplitTensorGetItem Value dim = splitTensorOp.getDim(); // add runtime.assert to check rank constraint: index < split_result_size - Value totalSize = rewriter.create(loc, input, dim); + Value totalSize = AtenSizeIntOp::create(rewriter, loc, input, dim); Value splitResultSize = getIntCeilDiv(rewriter, loc, totalSize, splitTensorOp.getSplitSize()); Value ltOrNot = - rewriter.create(loc, op.getIdx(), splitResultSize); - rewriter.create( - loc, ltOrNot, + AtenLtIntOp::create(rewriter, loc, op.getIdx(), splitResultSize); + RuntimeAssertOp::create( + rewriter, loc, ltOrNot, rewriter.getStringAttr("index should less than split_result_size")); Value step = - rewriter.create(loc, rewriter.getI64IntegerAttr(1)); - Value start = rewriter.create( - loc, rewriter.getI64IntegerAttr(index * splitSize)); - Value end = rewriter.create( - loc, rewriter.getI64IntegerAttr(index * splitSize + splitSize)); - Value sliceTensorOp = rewriter.create( - loc, op.getResult().getType(), input, dim, start, end, step); + ConstantIntOp::create(rewriter, loc, rewriter.getI64IntegerAttr(1)); + Value start = ConstantIntOp::create( + rewriter, loc, rewriter.getI64IntegerAttr(index * splitSize)); + Value end = ConstantIntOp::create( + rewriter, loc, + 
rewriter.getI64IntegerAttr(index * splitSize + splitSize)); + Value sliceTensorOp = AtenSliceTensorOp::create( + rewriter, loc, op.getResult().getType(), input, dim, start, end, step); rewriter.replaceOp(op, sliceTensorOp); if (splitTensorOp.getResult().use_empty()) rewriter.eraseOp(splitTensorOp); @@ -327,30 +329,30 @@ class RecomposeSplitTensorListUnpack Value dim = splitTensorOp.getDim(); // add runtime.assert to check rank constraint - Value totalSize = rewriter.create(loc, input, dim); - Value cstNumResults = rewriter.create( - loc, rewriter.getI64IntegerAttr(op.getNumResults())); + Value totalSize = AtenSizeIntOp::create(rewriter, loc, input, dim); + Value cstNumResults = ConstantIntOp::create( + rewriter, loc, rewriter.getI64IntegerAttr(op.getNumResults())); Value cstOne = - rewriter.create(loc, rewriter.getI64IntegerAttr(1)); + ConstantIntOp::create(rewriter, loc, rewriter.getI64IntegerAttr(1)); // assert: numResults == floordiv(totalSize + splitSize - 1, splitSize) Value splitResultSize = getIntCeilDiv(rewriter, loc, totalSize, splitTensorOp.getSplitSize()); Value eqOrNot = - rewriter.create(loc, splitResultSize, cstNumResults); - rewriter.create( - loc, eqOrNot, + AtenEqIntOp::create(rewriter, loc, splitResultSize, cstNumResults); + RuntimeAssertOp::create( + rewriter, loc, eqOrNot, rewriter.getStringAttr("numResults should equal to floordiv(totalSize " "+ splitSize - 1, splitSize)")); SmallVector slices; for (size_t i = 0; i < op.getNumResults(); i++) { auto resultTy = op.getResult(i).getType(); - auto start = rewriter.create( - loc, rewriter.getI64IntegerAttr(i * splitSize)); - auto end = rewriter.create( - loc, rewriter.getI64IntegerAttr((i + 1) * splitSize)); - Value sliceTensorOp = rewriter.create( - loc, resultTy, input, dim, start, end, /*step=*/cstOne); + auto start = Torch::ConstantIntOp::create( + rewriter, loc, rewriter.getI64IntegerAttr(i * splitSize)); + auto end = Torch::ConstantIntOp::create( + rewriter, loc, rewriter.getI64IntegerAttr((i + 
1) * splitSize)); + Value sliceTensorOp = AtenSliceTensorOp::create( + rewriter, loc, resultTy, input, dim, start, end, /*step=*/cstOne); slices.push_back(sliceTensorOp); } rewriter.replaceOp(op, slices); @@ -401,15 +403,15 @@ class RecomposeSplitWithSizesGetItem Value dim = splitWithSizesOp.getDim(); // add runtime.assert to check dimension constraint - Value totalSize = rewriter.create(loc, input, dim); + Value totalSize = AtenSizeIntOp::create(rewriter, loc, input, dim); int64_t sumSplitSize = std::accumulate(splitSizes.begin(), splitSizes.end(), 0); - Value cstSumSplitSize = rewriter.create( - loc, rewriter.getI64IntegerAttr(sumSplitSize)); + Value cstSumSplitSize = ConstantIntOp::create( + rewriter, loc, rewriter.getI64IntegerAttr(sumSplitSize)); Value eqOrNot = - rewriter.create(loc, totalSize, cstSumSplitSize); - rewriter.create( - loc, eqOrNot, + AtenEqIntOp::create(rewriter, loc, totalSize, cstSumSplitSize); + RuntimeAssertOp::create( + rewriter, loc, eqOrNot, rewriter.getStringAttr("split dim must be sum of split_sizes")); // replace with AtenSliceTensorOp @@ -418,13 +420,14 @@ class RecomposeSplitWithSizesGetItem boundaryOfSliceOp[i] = boundaryOfSliceOp[i - 1] + splitSizes[i - 1]; } Value cstOne = - rewriter.create(loc, rewriter.getI64IntegerAttr(1)); - auto start = rewriter.create( - loc, rewriter.getI64IntegerAttr(boundaryOfSliceOp[index])); - auto end = rewriter.create( - loc, rewriter.getI64IntegerAttr(boundaryOfSliceOp[index + 1])); - Value slice = rewriter.create( - loc, op.getType(), input, dim, start, end, /*step=*/cstOne); + ConstantIntOp::create(rewriter, loc, rewriter.getI64IntegerAttr(1)); + auto start = Torch::ConstantIntOp::create( + rewriter, loc, rewriter.getI64IntegerAttr(boundaryOfSliceOp[index])); + auto end = Torch::ConstantIntOp::create( + rewriter, loc, + rewriter.getI64IntegerAttr(boundaryOfSliceOp[index + 1])); + Value slice = AtenSliceTensorOp::create(rewriter, loc, op.getType(), input, + dim, start, end, /*step=*/cstOne); 
rewriter.replaceOp(op, slice); // erase splitOp if no user left if (splitWithSizesOp.getResult().use_empty()) @@ -476,15 +479,15 @@ class RecomposeSplitWithSizesListUnpack Value dim = splitOp.getDim(); // add runtime.assert to check rank constraint - Value totalSize = rewriter.create(loc, input, dim); + Value totalSize = AtenSizeIntOp::create(rewriter, loc, input, dim); int64_t sumSplitSize = std::accumulate(splitSizes.begin(), splitSizes.end(), 0); - Value cstSumSplitSize = rewriter.create( - loc, rewriter.getI64IntegerAttr(sumSplitSize)); + Value cstSumSplitSize = ConstantIntOp::create( + rewriter, loc, rewriter.getI64IntegerAttr(sumSplitSize)); Value eqOrNot = - rewriter.create(loc, totalSize, cstSumSplitSize); - rewriter.create( - loc, eqOrNot, + AtenEqIntOp::create(rewriter, loc, totalSize, cstSumSplitSize); + RuntimeAssertOp::create( + rewriter, loc, eqOrNot, rewriter.getStringAttr("split dim must be sum of split_sizes")); // calculate slice op's lower bound and up bound @@ -494,15 +497,16 @@ class RecomposeSplitWithSizesListUnpack } SmallVector slices; Value cstOne = - rewriter.create(loc, rewriter.getI64IntegerAttr(1)); + ConstantIntOp::create(rewriter, loc, rewriter.getI64IntegerAttr(1)); for (size_t i = 0; i < op.getNumResults(); i++) { auto resultTy = op.getResult(i).getType(); - auto start = rewriter.create( - loc, rewriter.getI64IntegerAttr(boundaryOfSliceOp[i])); - auto end = rewriter.create( - loc, rewriter.getI64IntegerAttr((boundaryOfSliceOp[i + 1]))); - Value sliceTensorOp = rewriter.create( - loc, resultTy, input, dim, start, end, /*step=*/cstOne); + auto start = Torch::ConstantIntOp::create( + rewriter, loc, rewriter.getI64IntegerAttr(boundaryOfSliceOp[i])); + auto end = Torch::ConstantIntOp::create( + rewriter, loc, + rewriter.getI64IntegerAttr((boundaryOfSliceOp[i + 1]))); + Value sliceTensorOp = AtenSliceTensorOp::create( + rewriter, loc, resultTy, input, dim, start, end, /*step=*/cstOne); slices.push_back(sliceTensorOp); } 
rewriter.replaceOp(op, slices); @@ -558,24 +562,27 @@ class RecomposeTensorSplitSectionsGetItem int64_t chunkSize = splitDimSize / sections; int64_t remain = splitDimSize % sections; Value cstOne = - rewriter.create(loc, rewriter.getI64IntegerAttr(1)); + ConstantIntOp::create(rewriter, loc, rewriter.getI64IntegerAttr(1)); Value result; if (index < remain) { - Value start = rewriter.create( - loc, rewriter.getI64IntegerAttr(index * (chunkSize + 1))); - Value end = rewriter.create( - loc, rewriter.getI64IntegerAttr((index + 1) * (chunkSize + 1))); - result = rewriter.create(loc, op.getType(), input, dim, - start, end, - /*step=*/cstOne); + Value start = ConstantIntOp::create( + rewriter, loc, rewriter.getI64IntegerAttr(index * (chunkSize + 1))); + Value end = ConstantIntOp::create( + rewriter, loc, + rewriter.getI64IntegerAttr((index + 1) * (chunkSize + 1))); + result = AtenSliceTensorOp::create(rewriter, loc, op.getType(), input, + dim, start, end, + /*step=*/cstOne); } else { - Value start = rewriter.create( - loc, rewriter.getI64IntegerAttr(index * chunkSize + remain)); - Value end = rewriter.create( - loc, rewriter.getI64IntegerAttr((index + 1) * chunkSize + remain)); - result = rewriter.create(loc, op.getType(), input, dim, - start, end, - /*step=*/cstOne); + Value start = ConstantIntOp::create( + rewriter, loc, + rewriter.getI64IntegerAttr(index * chunkSize + remain)); + Value end = ConstantIntOp::create( + rewriter, loc, + rewriter.getI64IntegerAttr((index + 1) * chunkSize + remain)); + result = AtenSliceTensorOp::create(rewriter, loc, op.getType(), input, + dim, start, end, + /*step=*/cstOne); } rewriter.replaceOp(op, result); // erase AtenTensorSplitSectionsOp if no user left @@ -625,25 +632,27 @@ class RecomposeTensorSplitSectionsListUnpack int64_t chunkSize = splitDimSize / sections; int64_t remain = splitDimSize % sections; Value cstOne = - rewriter.create(loc, rewriter.getI64IntegerAttr(1)); + ConstantIntOp::create(rewriter, loc, 
rewriter.getI64IntegerAttr(1)); SmallVector results; for (int64_t i = 0; i < sections; i++) { if (i < remain) { - Value start = rewriter.create( - loc, rewriter.getI64IntegerAttr(i * (chunkSize + 1))); - Value end = rewriter.create( - loc, rewriter.getI64IntegerAttr((i + 1) * (chunkSize + 1))); - Value slice = rewriter.create( - loc, op.getResult(i).getType(), input, dim, start, end, + Value start = ConstantIntOp::create( + rewriter, loc, rewriter.getI64IntegerAttr(i * (chunkSize + 1))); + Value end = ConstantIntOp::create( + rewriter, loc, + rewriter.getI64IntegerAttr((i + 1) * (chunkSize + 1))); + Value slice = AtenSliceTensorOp::create( + rewriter, loc, op.getResult(i).getType(), input, dim, start, end, /*step=*/cstOne); results.push_back(slice); } else { - Value start = rewriter.create( - loc, rewriter.getI64IntegerAttr(i * chunkSize + remain)); - Value end = rewriter.create( - loc, rewriter.getI64IntegerAttr((i + 1) * chunkSize + remain)); - Value slice = rewriter.create( - loc, op.getResult(i).getType(), input, dim, start, end, + Value start = ConstantIntOp::create( + rewriter, loc, rewriter.getI64IntegerAttr(i * chunkSize + remain)); + Value end = ConstantIntOp::create( + rewriter, loc, + rewriter.getI64IntegerAttr((i + 1) * chunkSize + remain)); + Value slice = AtenSliceTensorOp::create( + rewriter, loc, op.getResult(i).getType(), input, dim, start, end, /*step=*/cstOne); results.push_back(slice); } @@ -672,42 +681,42 @@ class RecomposeChunkListUnpack : public OpRewritePattern { Value input = chunkOp.getSelf(); Value chunks = chunkOp.getChunks(); Location loc = chunkOp.getLoc(); - Value totalSize = rewriter.create(loc, input, dim); + Value totalSize = Torch::AtenSizeIntOp::create(rewriter, loc, input, dim); // chunkSize = floordiv(totalSize + chunks - 1, chunks) Value chunkSize = getIntCeilDiv(rewriter, loc, totalSize, chunks); // add runtime.assert to check floordiv(totalSize + chunkSize - 1, // chunkSize) == NumResults - Value cstNumResults = 
rewriter.create( - loc, rewriter.getI64IntegerAttr(op.getNumResults())); + Value cstNumResults = ConstantIntOp::create( + rewriter, loc, rewriter.getI64IntegerAttr(op.getNumResults())); Value realChunks = getIntCeilDiv(rewriter, loc, totalSize, chunkSize); Value eqOrNot = - rewriter.create(loc, realChunks, cstNumResults); - rewriter.create( - loc, eqOrNot, + AtenEqIntOp::create(rewriter, loc, realChunks, cstNumResults); + RuntimeAssertOp::create( + rewriter, loc, eqOrNot, rewriter.getStringAttr( "chunks should equal to prim.list_unpack's num results")); Value cstOne = - rewriter.create(loc, rewriter.getI64IntegerAttr(1)); + ConstantIntOp::create(rewriter, loc, rewriter.getI64IntegerAttr(1)); SmallVector slices; for (size_t i = 0; i < op.getNumResults(); i++) { // rewrite to slice op with // start = chunkSize * i, // end = lastIndex ? totalSize : chunkSize * (i+1) auto resultTy = op.getResult(i).getType(); - auto index = rewriter.create( - op->getLoc(), rewriter.getI64IntegerAttr(i)); - auto start = rewriter.create(loc, index, chunkSize); + auto index = Torch::ConstantIntOp::create(rewriter, op->getLoc(), + rewriter.getI64IntegerAttr(i)); + auto start = AtenMulIntOp::create(rewriter, loc, index, chunkSize); Value end; if (i == op.getNumResults() - 1) { end = totalSize; } else { - auto nextIdx = rewriter.create(loc, index, cstOne); - end = rewriter.create(loc, nextIdx, chunkSize); + auto nextIdx = AtenAddIntOp::create(rewriter, loc, index, cstOne); + end = AtenMulIntOp::create(rewriter, loc, nextIdx, chunkSize); } - Value sliceTensorOp = rewriter.create( - loc, resultTy, input, dim, start, end, /*step=*/cstOne); + Value sliceTensorOp = AtenSliceTensorOp::create( + rewriter, loc, resultTy, input, dim, start, end, /*step=*/cstOne); slices.push_back(sliceTensorOp); } rewriter.replaceOp(op, slices); @@ -757,28 +766,29 @@ class RecomposeMeshgridIndexingListUnpack SmallVector expandShapeValues; for (int64_t i = 0; i < numTensors; i++) { expandShapeValues.push_back( - 
rewriter.create(loc, tensors[i])); + AtenNumelOp::create(rewriter, loc, tensors[i])); } - Value expandShapeList = rewriter.create( - loc, ListType::get(IntType::get(context)), expandShapeValues); + Value expandShapeList = PrimListConstructOp::create( + rewriter, loc, ListType::get(IntType::get(context)), expandShapeValues); SmallVector meshgrids; Value constFalse = - rewriter.create(loc, rewriter.getBoolAttr(false)); + ConstantBoolOp::create(rewriter, loc, rewriter.getBoolAttr(false)); for (auto [idx, tensor] : llvm::enumerate(tensors)) { Value constantOne = - rewriter.create(loc, rewriter.getI64IntegerAttr(1)); + ConstantIntOp::create(rewriter, loc, rewriter.getI64IntegerAttr(1)); SmallVector tensorViewShapeValues(numTensors, constantOne); tensorViewShapeValues[idx] = expandShapeValues[idx]; - Value viewShapeList = rewriter.create( - loc, ListType::get(IntType::get(context)), tensorViewShapeValues); + Value viewShapeList = PrimListConstructOp::create( + rewriter, loc, ListType::get(IntType::get(context)), + tensorViewShapeValues); Value view = - rewriter.create(loc, baseType, tensor, viewShapeList); + AtenViewOp::create(rewriter, loc, baseType, tensor, viewShapeList); - Value expandView = rewriter.create( - loc, baseType, view, expandShapeList, constFalse); + Value expandView = AtenExpandOp::create(rewriter, loc, baseType, view, + expandShapeList, constFalse); meshgrids.push_back(expandView); } diff --git a/lib/Dialect/Torch/Transforms/ReduceOpVariants.cpp b/lib/Dialect/Torch/Transforms/ReduceOpVariants.cpp index 5733b6b936c3..b84b4465eab5 100644 --- a/lib/Dialect/Torch/Transforms/ReduceOpVariants.cpp +++ b/lib/Dialect/Torch/Transforms/ReduceOpVariants.cpp @@ -30,11 +30,11 @@ static void createOverwriteTensorContents(PatternRewriter &rewriter, dyn_cast(overwrittenTensor.getType()) .getWithValueSemantics(); if (overwriterTensorType != overwrittenTensorType) { - overwriterTensor = rewriter.create( - loc, overwrittenTensorType, overwriterTensor); + overwriterTensor = 
TensorStaticInfoCastOp::create( + rewriter, loc, overwrittenTensorType, overwriterTensor); } - rewriter.create(loc, overwriterTensor, - overwrittenTensor); + OverwriteTensorContentsOp::create(rewriter, loc, overwriterTensor, + overwrittenTensor); } static Type getContainerOrTensorTypeWithValueSemantics(Type type) { @@ -93,8 +93,8 @@ class ConvertHasValueSemanticsOpsToValueTensors : public RewritePattern { for (OpOperand &opOperand : op->getOpOperands()) { Type operandType = opOperand.get().getType(); if (isa(operandType)) { - opOperand.set(rewriter.create(op->getLoc(), - opOperand.get())); + opOperand.set(CopyToValueTensorOp::create(rewriter, op->getLoc(), + opOperand.get())); } else if (auto listType = dyn_cast(operandType)) { if (!(isa(listType.getContainedType()) || isa(listType.getContainedType()))) @@ -130,8 +130,8 @@ class ConvertHasValueSemanticsOpsToValueTensors : public RewritePattern { auto newListElements = llvm::to_vector(llvm::map_range( listConstruct.getElements(), [&](Value tensor) -> Value { if (isa(tensor.getType())) { - return rewriter.create(op->getLoc(), - tensor); + return CopyToValueTensorOp::create(rewriter, op->getLoc(), + tensor); } return tensor; })); @@ -142,8 +142,8 @@ class ConvertHasValueSemanticsOpsToValueTensors : public RewritePattern { return rewriter.notifyMatchFailure( op, "Unable to convert list type to value semantics."); } - opOperand.set(rewriter.create( - op->getLoc(), newListType, newListElements)); + opOperand.set(PrimListConstructOp::create( + rewriter, op->getLoc(), newListType, newListElements)); } else if (auto optionalType = dyn_cast(operandType)) { // TODO: A more general way to handle the optional type is to // introduce a `copy.to_optional_vtensor` op. 
@@ -162,11 +162,11 @@ class ConvertHasValueSemanticsOpsToValueTensors : public RewritePattern { if (!isa(derefine.getOperand().getType())) continue; - auto newOperand = rewriter.create( - op->getLoc(), derefine.getOperand()); - opOperand.set(rewriter.create( - op->getLoc(), Torch::OptionalType::get(newOperand.getType()), - newOperand)); + auto newOperand = CopyToValueTensorOp::create(rewriter, op->getLoc(), + derefine.getOperand()); + opOperand.set(DerefineOp::create( + rewriter, op->getLoc(), + Torch::OptionalType::get(newOperand.getType()), newOperand)); } } // Convert all results. @@ -177,7 +177,7 @@ class ConvertHasValueSemanticsOpsToValueTensors : public RewritePattern { continue; result.setType(tensorType.getWithValueSemantics()); auto nonValueTensor = - rewriter.create(op->getLoc(), result); + CopyToNonValueTensorOp::create(rewriter, op->getLoc(), result); result.replaceAllUsesExcept(nonValueTensor, nonValueTensor); } rewriter.finalizeOpModification(op); @@ -247,11 +247,11 @@ void TorchMatchSpecializedBackendOp::populateSpecializedConversions( oldOperands[0], oldOperands[1], oldOperands[2], oldOperands[5], oldOperands[3], oldOperands[4], oldOperands[6]}; Value enableGQA = - rewriter.create(op->getLoc(), false); + ConstantBoolOp::create(rewriter, op->getLoc(), false); newOperands.push_back(enableGQA); - auto newOp = rewriter.create( - op.getLoc(), op->getResultTypes()[0], newOperands, + auto newOp = Torch::AtenScaledDotProductAttentionOp::create( + rewriter, op.getLoc(), op->getResultTypes()[0], newOperands, op->getAttrs()); rewriter.replaceOp(op, {newOp.getResult(), nullptr}); return success(); @@ -270,10 +270,10 @@ namespace { // int(ceil((end - start) / step)) Value calculateArangeResultNumElements(PatternRewriter &rewriter, Location loc, Value start, Value end, Value step) { - Value sub = rewriter.create( - loc, Torch::NumberType::get(rewriter.getContext()), end, start); - Value div = rewriter.create(loc, sub, step); - return rewriter.create(loc, div); + 
Value sub = AtenSubOp::create( + rewriter, loc, Torch::NumberType::get(rewriter.getContext()), end, start); + Value div = AtenDivOp::create(rewriter, loc, sub, step); + return AtenCeilFloatOp::create(rewriter, loc, div); } class ReduceNonValueSemanticOps : public RewritePattern { @@ -285,10 +285,10 @@ class ReduceNonValueSemanticOps : public RewritePattern { Location loc = op->getLoc(); MLIRContext *ctx = op->getContext(); if (isa(op)) { - Operation *newOp = rewriter.create( - loc, op->getResultTypes(), op->getOperands()); + Operation *newOp = ValsemVariantAtenBernoulliFloatOp::create( + rewriter, loc, op->getResultTypes(), op->getOperands()); auto tensor = - rewriter.create(loc, newOp->getResult(0)); + CopyToValueTensorOp::create(rewriter, loc, newOp->getResult(0)); createOverwriteTensorContents(rewriter, loc, tensor, op->getOperand(0)); rewriter.replaceOp(op, op->getOperand(0)); return success(); @@ -305,26 +305,27 @@ class ReduceNonValueSemanticOps : public RewritePattern { // `y = torch.arange(13, out=x)` Value resultNumElements = calculateArangeResultNumElements(rewriter, loc, start, end, step); - Value outNumElements = rewriter.create(loc, out); + Value outNumElements = AtenNumelOp::create(rewriter, loc, out); Value eqOrNot = - rewriter.create(loc, resultNumElements, outNumElements); - rewriter.create( - loc, eqOrNot, + AtenEqIntOp::create(rewriter, loc, resultNumElements, outNumElements); + RuntimeAssertOp::create( + rewriter, loc, eqOrNot, rewriter.getStringAttr("`out` tensor should have the same " "num_elements with result tenosr")); - auto dtype = rewriter.create(loc, out); - auto device = rewriter.create(loc, out); - auto shape = rewriter.create( - loc, Torch::ListType::get(Torch::IntType::get(ctx)), out); - auto none = rewriter.create(loc); - Value newArange = rewriter.create( - loc, arangeOutOp.getResult().getType(), start, end, step, dtype, + auto dtype = PrimDtypeOp::create(rewriter, loc, out); + auto device = PrimDeviceOp::create(rewriter, loc, 
out); + auto shape = AtenSizeOp::create( + rewriter, loc, Torch::ListType::get(Torch::IntType::get(ctx)), out); + auto none = ConstantNoneOp::create(rewriter, loc); + Value newArange = AtenArangeStartStepOp::create( + rewriter, loc, arangeOutOp.getResult().getType(), start, end, step, + dtype, /*layout=*/none, device, /*pin_memory=*/none); - Value reshape = rewriter.create( - loc, arangeOutOp.getResult().getType(), newArange, shape); + Value reshape = AtenReshapeOp::create( + rewriter, loc, arangeOutOp.getResult().getType(), newArange, shape); - auto vtensor = rewriter.create(loc, reshape); + auto vtensor = CopyToValueTensorOp::create(rewriter, loc, reshape); createOverwriteTensorContents(rewriter, loc, vtensor, out); rewriter.replaceOp(arangeOutOp, out); return success(); @@ -370,14 +371,16 @@ class ReduceTrailingUnderscoreInplaceVariant : public RewritePattern { // b = torch.randn(3, 3) # float32 // a += b # i.e. torch.ops.aten.add_(a, b), result is float16 // c = a + b # i.e. torch.ops.aten.add(a, b), result is float32 - Value none = rewriter.create(op->getLoc()); - Value cstFalse = rewriter.create(op->getLoc(), false); - auto aDtype = rewriter.create(op->getLoc(), op->getOperand(0)); - auto toDtype = rewriter.create( - op->getLoc(), newOp->getResult(0).getType(), newOp->getResult(0), - aDtype, /*non_blocking=*/cstFalse, /*copy=*/cstFalse, + Value none = ConstantNoneOp::create(rewriter, op->getLoc()); + Value cstFalse = ConstantBoolOp::create(rewriter, op->getLoc(), false); + auto aDtype = + PrimDtypeOp::create(rewriter, op->getLoc(), op->getOperand(0)); + auto toDtype = AtenToDtypeOp::create( + rewriter, op->getLoc(), newOp->getResult(0).getType(), + newOp->getResult(0), aDtype, /*non_blocking=*/cstFalse, + /*copy=*/cstFalse, /*memory_format=*/none); - auto tensor = rewriter.create(op->getLoc(), toDtype); + auto tensor = CopyToValueTensorOp::create(rewriter, op->getLoc(), toDtype); createOverwriteTensorContents(rewriter, op->getLoc(), tensor, 
op->getOperand(0)); rewriter.replaceOp(op, op->getOperand(0)); @@ -391,7 +394,7 @@ static LogicalResult reduceNonValueTensorLiteralOpToValueTensorLiteralOp(NonValueTensorLiteralOp op, PatternRewriter &rewriter) { Value valueTensor = - rewriter.create(op->getLoc(), op.getValue()); + ValueTensorLiteralOp::create(rewriter, op->getLoc(), op.getValue()); Value tensor = copyTensorToType(rewriter, op->getLoc(), op.getType(), valueTensor); rewriter.replaceOp(op, {tensor}); diff --git a/lib/Dialect/Torch/Transforms/ReifyAbstractInterpCalculationsUtils.cpp b/lib/Dialect/Torch/Transforms/ReifyAbstractInterpCalculationsUtils.cpp index cd6126aa4da5..ff051e343732 100644 --- a/lib/Dialect/Torch/Transforms/ReifyAbstractInterpCalculationsUtils.cpp +++ b/lib/Dialect/Torch/Transforms/ReifyAbstractInterpCalculationsUtils.cpp @@ -35,9 +35,9 @@ static Operation *createCalculateOp(OpBuilder &b, Location loc, TypeRange resultTypes, LibraryFunctionKind libFuncKind) { if (libFuncKind == LibraryFunctionKind::ShapeFunction) - return b.create(loc, resultTypes); + return ShapeCalculateOp::create(b, loc, resultTypes); else if (libFuncKind == LibraryFunctionKind::DtypeFunction) - return b.create(loc, resultTypes); + return DtypeCalculateOp::create(b, loc, resultTypes); llvm_unreachable( "`createCalculateOp` called with an unsupported `LibraryFunctionKind`"); } @@ -46,9 +46,9 @@ static Operation *createCalculateYieldOp(OpBuilder &b, Location loc, ValueRange results, LibraryFunctionKind libFuncKind) { if (libFuncKind == LibraryFunctionKind::ShapeFunction) - return b.create(loc, results); + return ShapeCalculateYieldOp::create(b, loc, results); else if (libFuncKind == LibraryFunctionKind::DtypeFunction) - return b.create(loc, results); + return DtypeCalculateYieldOp::create(b, loc, results); llvm_unreachable("`createCalculateYieldOp` called with an unsupported " "`LibraryFunctionKind`"); } @@ -58,9 +58,9 @@ createCalculateYieldCalculationOp(OpBuilder &b, Location loc, ValueRange results, 
LibraryFunctionKind libFuncKind) { if (libFuncKind == LibraryFunctionKind::ShapeFunction) - return b.create(loc, results); + return ShapeCalculateYieldShapesOp::create(b, loc, results); else if (libFuncKind == LibraryFunctionKind::DtypeFunction) - return b.create(loc, results); + return DtypeCalculateYieldDtypesOp::create(b, loc, results); llvm_unreachable("`createCalculateYieldCalculationOp` called with an " "unsupported `LibraryFunctionKind`"); } @@ -110,7 +110,7 @@ LogicalResult Torch::wrapWithCalculateOpIfLibraryFunctionAvailable( libFuncArgsBuilder(b, loc, op->getOperands(), libFunc); if (failed(libFuncArgs)) return failure(); - auto call = b.create(loc, libFunc, *libFuncArgs); + auto call = mlir::func::CallOp::create(b, loc, libFunc, *libFuncArgs); // Python models multiple results with a tuple, so we need to unpack it // if the op has multiple results. @@ -119,8 +119,8 @@ LogicalResult Torch::wrapWithCalculateOpIfLibraryFunctionAvailable( "Multiple results are packed in a tuple in Python!"); Value result = call.getResult(0); if (auto tupleType = dyn_cast(result.getType())) { - auto unpack = b.create( - loc, tupleType.getContainedTypes(), result); + auto unpack = PrimTupleUnpackOp::create( + b, loc, tupleType.getContainedTypes(), result); llvm::append_range(unpackedResults, unpack.getResults()); } else { unpackedResults.push_back(result); @@ -175,14 +175,14 @@ Torch::adjustFunctionArg(OpBuilder &b, Location loc, Value operand, // compile a function with Generator type arguments. // Ignoring that hack, this is a correct handling of Any type should we need // to actually support it in the future. - return b.create(loc, desiredType, operand).getResult(); + return DerefineOp::create(b, loc, desiredType, operand).getResult(); } // The type `!torch.number` can be an `int`, `float`, or `complex`. // TODO: Add a new type `Torch::ComplexType` to handle the complex case. 
if (isa(desiredType) && isa(operandType)) { - return b.create(loc, desiredType, operand).getResult(); + return DerefineOp::create(b, loc, desiredType, operand).getResult(); } // !torch.union is the type used for optional @@ -194,7 +194,7 @@ Torch::adjustFunctionArg(OpBuilder &b, Location loc, Value operand, return isa( containedType); })) - return b.create(loc, desiredType, operand).getResult(); + return DerefineOp::create(b, loc, desiredType, operand).getResult(); } // Operands with type `!torch.none` correspond to library function inputs with @@ -203,7 +203,7 @@ Torch::adjustFunctionArg(OpBuilder &b, Location loc, Value operand, if (isa(operandType)) { assert(!isa(desiredType) && "Don't expect library functions to have NoneType parameters"); - return b.create(loc, desiredType, operand).getResult(); + return DerefineOp::create(b, loc, desiredType, operand).getResult(); } // To keep things simple in shape functions, `Scalar` inputs are considered @@ -213,7 +213,7 @@ Torch::adjustFunctionArg(OpBuilder &b, Location loc, Value operand, // into the IR. 
if (isa(operandType) && isa(desiredType)) { - return b.create(loc, desiredType, operand).getResult(); + return AtenFloatScalarOp::create(b, loc, desiredType, operand).getResult(); } // If the operand type is statically !torch.optional, then we need to do @@ -230,25 +230,25 @@ Torch::adjustFunctionArg(OpBuilder &b, Location loc, Value operand, // return derefine(None) // else: // return adjust(unchecked_cast(optional)) - auto none = b.create(loc); - auto isNone = b.create(loc, operand, none); - auto primIf = b.create(loc, desiredType, isNone); + auto none = ConstantNoneOp::create(b, loc); + auto isNone = Aten__Is__Op::create(b, loc, operand, none); + auto primIf = PrimIfOp::create(b, loc, desiredType, isNone); { Region &thenRegion = primIf.getThenRegion(); b.createBlock(&thenRegion, thenRegion.end()); - auto derefineNone = b.create(loc, desiredType, none); - b.create(loc, ValueRange{derefineNone}); + auto derefineNone = DerefineOp::create(b, loc, desiredType, none); + PrimIfYieldOp::create(b, loc, ValueRange{derefineNone}); } { Region &elseRegion = primIf.getElseRegion(); b.createBlock(&elseRegion, elseRegion.end()); - auto downcasted = b.create( - loc, operandOptionalType.getContainedType(), operand); + auto downcasted = PrimUncheckedCastOp::create( + b, loc, operandOptionalType.getContainedType(), operand); FailureOr adjusted = adjustFunctionArg( b, loc, downcasted, desiredType, baseTransformation); if (failed(adjusted)) return failure(); - b.create(loc, *adjusted); + PrimIfYieldOp::create(b, loc, *adjusted); } b.setInsertionPointAfter(primIf); return primIf.getResult(0); @@ -264,7 +264,7 @@ Torch::adjustFunctionArg(OpBuilder &b, Location loc, Value operand, baseTransformation); if (failed(adjusted)) return failure(); - return b.create(loc, desiredType, *adjusted).getResult(); + return DerefineOp::create(b, loc, desiredType, *adjusted).getResult(); } if (auto desiredListType = dyn_cast(desiredType)) { @@ -277,13 +277,13 @@ Torch::adjustFunctionArg(OpBuilder &b, 
Location loc, Value operand, // return adjusted_list auto providedType = cast(operand.getType()); Value adjustedList = - b.create(loc, desiredListType, ValueRange({})); + PrimListConstructOp::create(b, loc, desiredListType, ValueRange({})); // Create a for-like PrimLoopOp. - Value maxTripCount = b.create(loc, operand); - Value cTrue = b.create(loc, true); - auto loop = b.create(loc, TypeRange({}), maxTripCount, - /*initialCondition=*/cTrue, - /*iterArgsInit=*/ValueRange({})); + Value maxTripCount = AtenLenTOp::create(b, loc, operand); + Value cTrue = Torch::ConstantBoolOp::create(b, loc, true); + auto loop = PrimLoopOp::create(b, loc, TypeRange({}), maxTripCount, + /*initialCondition=*/cTrue, + /*iterArgsInit=*/ValueRange({})); // Create the loop body. { @@ -292,17 +292,17 @@ Torch::adjustFunctionArg(OpBuilder &b, Location loc, Value operand, b.createBlock(&loop.getRegion(), loop.getRegion().begin(), TypeRange({b.getType()}), {loc}); Value iterationNumber = body->getArgument(0); - Value element = b.create( - loc, providedType.getContainedType(), operand, iterationNumber); + Value element = Aten__Getitem__TOp::create( + b, loc, providedType.getContainedType(), operand, iterationNumber); FailureOr adjustedElement = adjustFunctionArg(b, loc, element, desiredListType.getContainedType(), baseTransformation); if (failed(adjustedElement)) return failure(); - b.create(loc, adjustedList.getType(), adjustedList, - *adjustedElement); - b.create(loc, /*shouldContinue=*/cTrue, - /*iterArgs=*/ValueRange({})); + AtenAppendTOp::create(b, loc, adjustedList.getType(), adjustedList, + *adjustedElement); + PrimLoopConditionOp::create(b, loc, /*shouldContinue=*/cTrue, + /*iterArgs=*/ValueRange({})); } return adjustedList; @@ -313,7 +313,7 @@ Torch::adjustFunctionArg(OpBuilder &b, Location loc, Value operand, // explanation). 
if (isa(desiredType) && isa(operand.getType())) { - return b.create(loc, desiredType, operand).getResult(); + return AtenFloatScalarOp::create(b, loc, desiredType, operand).getResult(); } // Pass the operand as-is. diff --git a/lib/Dialect/Torch/Transforms/ReifyDtypeCalculations.cpp b/lib/Dialect/Torch/Transforms/ReifyDtypeCalculations.cpp index 3e9ec336641d..790fd80a2f71 100644 --- a/lib/Dialect/Torch/Transforms/ReifyDtypeCalculations.cpp +++ b/lib/Dialect/Torch/Transforms/ReifyDtypeCalculations.cpp @@ -33,11 +33,11 @@ dtypeFunctionArgsBuilder(OpBuilder &b, Location loc, isa(operand.getType())) { Type intType = Torch::IntType::get(b.getContext()); Type sizeListType = Torch::ListType::get(intType); - Value size = b.create(loc, sizeListType, operand); - Value rank = b.create(loc, intType, size); - Value dtype = b.create(loc, intType, operand); - return b.create(loc, desiredType, - ArrayRef{rank, dtype}); + Value size = AtenSizeOp::create(b, loc, sizeListType, operand); + Value rank = AtenLenTOp::create(b, loc, intType, size); + Value dtype = PrimDtypeOp::create(b, loc, intType, operand); + return PrimTupleConstructOp::create(b, loc, desiredType, + ArrayRef{rank, dtype}); } return operand; }; diff --git a/lib/Dialect/Torch/Transforms/ReifyShapeCalculations.cpp b/lib/Dialect/Torch/Transforms/ReifyShapeCalculations.cpp index fb9d33123a9c..4b81970909d2 100644 --- a/lib/Dialect/Torch/Transforms/ReifyShapeCalculations.cpp +++ b/lib/Dialect/Torch/Transforms/ReifyShapeCalculations.cpp @@ -43,7 +43,7 @@ shapeFunctionArgsBuilder(OpBuilder &b, Location loc, return operand; if (isa(operand.getType()) && isa(desiredListType.getContainedType())) { - return b.create(loc, desiredType, operand); + return AtenSizeOp::create(b, loc, desiredType, operand); } return operand; }); diff --git a/lib/Dialect/Torch/Transforms/RestructureNonConstantAxes.cpp b/lib/Dialect/Torch/Transforms/RestructureNonConstantAxes.cpp index dd7e37320777..0ea79a02a799 100644 --- 
a/lib/Dialect/Torch/Transforms/RestructureNonConstantAxes.cpp +++ b/lib/Dialect/Torch/Transforms/RestructureNonConstantAxes.cpp @@ -88,8 +88,8 @@ class ConstantifyDimArgument : public OpRewritePattern { Type intType = rewriter.getType(); Type boolType = rewriter.getType(); auto createInt = [&](int value) { - return rewriter.create( - loc, intType, + return Torch::ConstantIntOp::create( + rewriter, loc, intType, rewriter.getIntegerAttr(rewriter.getIntegerType(64), value)); }; Value zero = createInt(0); @@ -98,23 +98,24 @@ class ConstantifyDimArgument : public OpRewritePattern { // handle when dim is a single element list bool oldDimIsList = isa(dim.getType()); if (oldDimIsList) { - Value len = rewriter.create(loc, intType, dim); + Value len = Torch::AtenLenTOp::create(rewriter, loc, intType, dim); Value dimListIsLengthOne = - rewriter.create(loc, boolType, len, one); - rewriter.create( - loc, dimListIsLengthOne, + Torch::AtenEqIntOp::create(rewriter, loc, boolType, len, one); + Torch::RuntimeAssertOp::create( + rewriter, loc, dimListIsLengthOne, rewriter.getStringAttr("RestructureNonConstantAxes does not support " "dim lists with more than one element")); - dim = rewriter.create(loc, intType, dim, zero); + dim = + Torch::Aten__Getitem__TOp::create(rewriter, loc, intType, dim, zero); } // Normalize negative dim - Value rank = rewriter.create(loc, intType, self); - Value isNegative = rewriter.create(loc, dim, zero); - Value rankOffset = rewriter.create( - loc, intType, - rewriter.create(loc, intType, isNegative), rank); - dim = rewriter.create(loc, intType, dim, rankOffset); + Value rank = Torch::AtenDimOp::create(rewriter, loc, intType, self); + Value isNegative = Torch::AtenLtIntOp::create(rewriter, loc, dim, zero); + Value rankOffset = Torch::AtenMulIntOp::create( + rewriter, loc, intType, + Torch::AtenIntBoolOp::create(rewriter, loc, intType, isNegative), rank); + dim = Torch::AtenAddIntOp::create(rewriter, loc, intType, dim, rankOffset); auto 
createConditionalMult = [&](Value self, Value multiplier, Value condition) { @@ -125,16 +126,17 @@ class ConstantifyDimArgument : public OpRewritePattern { // which translates to: // result = multiplier - 1 - Value result = rewriter.create( - loc, intType, multiplier, createInt(1)); + Value result = Torch::AtenSubIntOp::create(rewriter, loc, intType, + multiplier, createInt(1)); // result = result * condition - result = - rewriter.create(loc, intType, result, condition); + result = Torch::AtenMulIntOp::create(rewriter, loc, intType, result, + condition); // result = result + 1 - result = rewriter.create(loc, intType, result, - createInt(1)); + result = Torch::AtenAddIntOp::create(rewriter, loc, intType, result, + createInt(1)); // result = self * result - result = rewriter.create(loc, intType, self, result); + result = + Torch::AtenMulIntOp::create(rewriter, loc, intType, self, result); return result; }; @@ -146,29 +148,29 @@ class ConstantifyDimArgument : public OpRewritePattern { for (size_t i = 0; i < selfTy.getSizes().size(); ++i) { Value idx = createInt(i); Value size = - rewriter.create(loc, intType, self, idx); + Torch::AtenSizeIntOp::create(rewriter, loc, intType, self, idx); Value isBeforeDim = - rewriter.create(loc, boolType, idx, dim); + Torch::AtenLtIntOp::create(rewriter, loc, boolType, idx, dim); isBeforeDim = - rewriter.create(loc, intType, isBeforeDim); + Torch::AtenIntBoolOp::create(rewriter, loc, intType, isBeforeDim); Value isAfterDim = - rewriter.create(loc, boolType, idx, dim); + Torch::AtenGtIntOp::create(rewriter, loc, boolType, idx, dim); isAfterDim = - rewriter.create(loc, intType, isAfterDim); + Torch::AtenIntBoolOp::create(rewriter, loc, intType, isAfterDim); Value isEqualToDim = - rewriter.create(loc, boolType, idx, dim); + Torch::AtenEqIntOp::create(rewriter, loc, boolType, idx, dim); isEqualToDim = - rewriter.create(loc, intType, isEqualToDim); + Torch::AtenIntBoolOp::create(rewriter, loc, intType, isEqualToDim); dimSize = 
createConditionalMult(dimSize, size, isEqualToDim); beforeProd = createConditionalMult(beforeProd, size, isBeforeDim); afterProd = createConditionalMult(afterProd, size, isAfterDim); } - Value newShape = rewriter.create( - loc, rewriter.getType(intType), + Value newShape = Torch::PrimListConstructOp::create( + rewriter, loc, rewriter.getType(intType), ValueRange{beforeProd, dimSize, afterProd}); // Reshape input @@ -177,14 +179,15 @@ class ConstantifyDimArgument : public OpRewritePattern { Torch::kUnknownSize}, selfTy.getDtype()); Value reshapedSelf = - rewriter.create(loc, newSelfTy, self, newShape); + Torch::AtenViewOp::create(rewriter, loc, newSelfTy, self, newShape); // construct new operange range where self is replaced with reshapedSelf // tensor, and dim is replaced with 1 Value newDim; if (oldDimIsList) { - newDim = rewriter.create( - loc, rewriter.getType(intType), ValueRange{one}); + newDim = Torch::PrimListConstructOp::create( + rewriter, loc, rewriter.getType(intType), + ValueRange{one}); } else { newDim = one; } @@ -208,7 +211,7 @@ class ConstantifyDimArgument : public OpRewritePattern { resultTy.getDtype())); Value newReductionOp = - rewriter.create(loc, newResultTy, newOperands, op->getAttrs()); + SrcOp::create(rewriter, loc, newResultTy, newOperands, op->getAttrs()); // Reshape the result back to original shape ValueTensorType oldResultTy = @@ -217,10 +220,11 @@ class ConstantifyDimArgument : public OpRewritePattern { for (auto dim : oldResultTy.getSizes()) { shapeValues.push_back(createInt(dim)); } - Value originalShape = rewriter.create( - loc, rewriter.getType(intType), shapeValues); - Value result = rewriter.create( - loc, op->getResult(0).getType(), newReductionOp, originalShape); + Value originalShape = Torch::PrimListConstructOp::create( + rewriter, loc, rewriter.getType(intType), shapeValues); + Value result = + Torch::AtenViewOp::create(rewriter, loc, op->getResult(0).getType(), + newReductionOp, originalShape); rewriter.replaceOp(op, 
result); return success(); diff --git a/lib/Dialect/Torch/Transforms/ScalarizeShapes.cpp b/lib/Dialect/Torch/Transforms/ScalarizeShapes.cpp index 94b21e99421e..d6db40d0c182 100644 --- a/lib/Dialect/Torch/Transforms/ScalarizeShapes.cpp +++ b/lib/Dialect/Torch/Transforms/ScalarizeShapes.cpp @@ -38,13 +38,13 @@ LogicalResult materializeFolds(ImplicitLocOpBuilder b, if (auto attr = dyn_cast(f)) { if (auto val = dyn_cast(attr)) { values.push_back( - b.create(APFloat(val.getValueAsDouble()))); + Torch::ConstantFloatOp::create(b, APFloat(val.getValueAsDouble()))); continue; } if (auto val = dyn_cast(attr)) { values.push_back( - b.create(val.getValue().getSExtValue())); + Torch::ConstantIntOp::create(b, val.getValue().getSExtValue())); continue; } } @@ -148,12 +148,12 @@ LogicalResult getListFromTensor(Value value, SmallVector &vals) { Value constructAtenTensorOpFromList(ImplicitLocOpBuilder b, mlir::Type resultTy, SmallVector &listValues) { - auto dimList = b.create( - b.getType(listValues.front().getType()), listValues); - Value cstNone = b.create(); - Value cstFalse = b.create(b.getBoolAttr(false)); - return b.create(resultTy, dimList, cstNone, cstNone, - cstFalse); + auto dimList = Torch::PrimListConstructOp::create( + b, b.getType(listValues.front().getType()), listValues); + Value cstNone = Torch::ConstantNoneOp::create(b); + Value cstFalse = Torch::ConstantBoolOp::create(b, b.getBoolAttr(false)); + return Torch::AtenTensorOp::create(b, resultTy, dimList, cstNone, cstNone, + cstFalse); } } // namespace @@ -190,12 +190,13 @@ class PropagateAtenBroadcastToPattern if (failed(materializeFolds(b, fillFold, fillVals))) return failure(); - Value size = b.create(ty.getSizes().front()); - Value sizeList = b.create( + Value size = Torch::ConstantIntOp::create(b, ty.getSizes().front()); + Value sizeList = Torch::PrimListConstructOp::create( + b, rewriter.getType(rewriter.getType()), size); - Value none = b.create(); - Value cstFalse = b.create(false); + Value none = 
Torch::ConstantNoneOp::create(b); + Value cstFalse = Torch::ConstantBoolOp::create(b, false); rewriter.replaceOpWithNewOp(op, ty, sizeList, fillVals.front(), none, none, none, cstFalse); return success(); @@ -220,7 +221,7 @@ class PropagateAtenShapeToTensorPattern int64_t rank = selfTy.getSizes().size(); SmallVector dims; for (int64_t i = 0; i < rank; ++i) { - auto iv = b.create(i); + auto iv = Torch::ConstantIntOp::create(b, i); dims.push_back(b.createOrFold( rewriter.getType(), self, iv)); } @@ -623,16 +624,16 @@ class PropagateAtenWhereSelfPattern : public OpRewritePattern { auto rank0BoolTy = rewriter.getType( ArrayRef({}), conditionTy.getDtype()); for (uint64_t i = 0; i < selfList.size(); i++) { - Value rank0Cond = b.create( - rank0BoolTy, conditionList[i]); + Value rank0Cond = Torch::PrimNumToTensorScalarOp::create( + b, rank0BoolTy, conditionList[i]); Value rank0Self = - b.create(rank0IntTy, selfList[i]); + Torch::PrimNumToTensorScalarOp::create(b, rank0IntTy, selfList[i]); Value rank0Other = - b.create(rank0IntTy, otherList[i]); - Value rank0Where = b.create(rank0IntTy, rank0Cond, - rank0Self, rank0Other); - whereVals.push_back( - b.create(rewriter.getType(), rank0Where)); + Torch::PrimNumToTensorScalarOp::create(b, rank0IntTy, otherList[i]); + Value rank0Where = AtenWhereSelfOp::create(b, rank0IntTy, rank0Cond, + rank0Self, rank0Other); + whereVals.push_back(AtenItemOp::create( + b, rewriter.getType(), rank0Where)); } Value result = constructAtenTensorOpFromList(b, op.getType(), whereVals); rewriter.replaceOp(op, result); @@ -1009,22 +1010,23 @@ class FoldAtenEqIntPattern : public OpRewritePattern { if (auto mulOp = op.getA().getDefiningOp()) { Value self = mulOp.getA(); Value other = mulOp.getB(); - Value selfEq = rewriter.create(op.getLoc(), self, op.getB()); + Value selfEq = + AtenEqIntOp::create(rewriter, op.getLoc(), self, op.getB()); Value otherEq = - rewriter.create(op.getLoc(), other, op.getB()); + AtenEqIntOp::create(rewriter, op.getLoc(), other, 
op.getB()); rewriter.replaceOpWithNewOp(op, selfEq, otherEq); return success(); } // if lhs is size.int op, assert size > 0 and replace with false. if (auto sizeOp = op.getA().getDefiningOp()) { - Value selfGtOther = rewriter.create( - op.getLoc(), op.getType(), op.getA(), op.getB()); - rewriter.create( - op.getLoc(), selfGtOther, + Value selfGtOther = AtenGtIntOp::create( + rewriter, op.getLoc(), op.getType(), op.getA(), op.getB()); + Torch::RuntimeAssertOp::create( + rewriter, op.getLoc(), selfGtOther, rewriter.getStringAttr("Expected dim size > 0.")); Value cstFalse = - rewriter.create(op.getLoc(), false); + Torch::ConstantBoolOp::create(rewriter, op.getLoc(), false); rewriter.replaceOp(op, cstFalse); return success(); } @@ -1069,16 +1071,16 @@ class FoldAtenTensorSplatPattern : public OpRewritePattern { auto loc = op.getLoc(); SmallVector sizes; for (auto size : resultTy.getSizes()) - sizes.push_back(rewriter.create( - loc, rewriter.getI64IntegerAttr(size))); + sizes.push_back(Torch::ConstantIntOp::create( + rewriter, loc, rewriter.getI64IntegerAttr(size))); - Value sizeList = rewriter.create( - loc, + Value sizeList = Torch::PrimListConstructOp::create( + rewriter, loc, rewriter.getType(rewriter.getType()), sizes); - Value none = rewriter.create(loc); - Value cstFalse = rewriter.create(loc, false); + Value none = Torch::ConstantNoneOp::create(rewriter, loc); + Value cstFalse = Torch::ConstantBoolOp::create(rewriter, loc, false); rewriter.replaceOpWithNewOp( op, resultTy, sizeList, elements.front(), none, none, none, cstFalse); return success(); @@ -1107,16 +1109,16 @@ class FoldAtenSqueezePattern : public OpRewritePattern { } SmallVector sizes; for (int i = 0, s = resultTy.getSizes().size(); i < s; ++i) - sizes.push_back(rewriter.create( - op.getLoc(), rewriter.getType(), + sizes.push_back(Torch::ConstantIntOp::create( + rewriter, op.getLoc(), rewriter.getType(), rewriter.getI64IntegerAttr(i))); - Value sizeList = rewriter.create( - op.getLoc(), + Value 
sizeList = Torch::PrimListConstructOp::create( + rewriter, op.getLoc(), rewriter.getType(rewriter.getType()), sizes); - Value none = rewriter.create(op.getLoc()); + Value none = Torch::ConstantNoneOp::create(rewriter, op.getLoc()); rewriter.replaceOpWithNewOp(op, resultTy, sizeList, atenFull.getFillValue(), none, none, none, none); @@ -1210,16 +1212,16 @@ class FoldAtenUnsqueezePattern : public OpRewritePattern { if (auto atenFull = op.getSelf().getDefiningOp()) { SmallVector sizes; for (int i = 0, s = resultTy.getSizes().size(); i < s; ++i) - sizes.push_back(rewriter.create( - op.getLoc(), rewriter.getType(), + sizes.push_back(Torch::ConstantIntOp::create( + rewriter, op.getLoc(), rewriter.getType(), rewriter.getI64IntegerAttr(i))); - Value sizeList = rewriter.create( - op.getLoc(), + Value sizeList = Torch::PrimListConstructOp::create( + rewriter, op.getLoc(), rewriter.getType(rewriter.getType()), sizes); - Value none = rewriter.create(op.getLoc()); + Value none = Torch::ConstantNoneOp::create(rewriter, op.getLoc()); rewriter.replaceOpWithNewOp(op, resultTy, sizeList, atenFull.getFillValue(), none, none, none, none); @@ -1347,7 +1349,7 @@ class CanonicalizeAtenViewPattern : public OpRewritePattern { // if input has 1 unmatched dim, and output has multiple, unflatten if (inputUnmatched == 1 && outputUnmatched > 1) { Value dimVal = - rewriter.create(op.getLoc(), leftMatchEnd); + Torch::ConstantIntOp::create(rewriter, op.getLoc(), leftMatchEnd); SmallVector unflattenSizes(viewSizes.begin() + leftMatchEnd, viewSizes.end() - rightMatchEnd); // try to convert a single dynamic size input to -1 @@ -1369,10 +1371,10 @@ class CanonicalizeAtenViewPattern : public OpRewritePattern { // if only one size is dynamic, make it -1 if (dynCount == 1) unflattenSizes[dynIdx] = - rewriter.create(op.getLoc(), -1); + Torch::ConstantIntOp::create(rewriter, op.getLoc(), -1); - Value unflattenList = rewriter.create( - op.getLoc(), op.getSize().getType(), unflattenSizes); + Value 
unflattenList = Torch::PrimListConstructOp::create( + rewriter, op.getLoc(), op.getSize().getType(), unflattenSizes); rewriter.replaceOpWithNewOp( op, op.getType(), op.getSelf(), dimVal, unflattenList); return success(); @@ -1380,10 +1382,11 @@ class CanonicalizeAtenViewPattern : public OpRewritePattern { // if multiple unmatched input dims map to one output dim, flatten if (inputUnmatched > 1 && outputUnmatched == 1) { Value startDim = - rewriter.create(op.getLoc(), leftMatchEnd); + Torch::ConstantIntOp::create(rewriter, op.getLoc(), leftMatchEnd); // note: flatten end is inclusive for some reason. int64_t endInt = inRank - rightMatchEnd - 1; - Value endDim = rewriter.create(op.getLoc(), endInt); + Value endDim = + Torch::ConstantIntOp::create(rewriter, op.getLoc(), endInt); rewriter.replaceOpWithNewOp( op, op.getType(), op.getSelf(), startDim, endDim); return success(); diff --git a/lib/Dialect/Torch/Transforms/SimplifyAbstractInterpCalculationsUtils.cpp b/lib/Dialect/Torch/Transforms/SimplifyAbstractInterpCalculationsUtils.cpp index d599fd5369f4..6c3944ea2c8b 100644 --- a/lib/Dialect/Torch/Transforms/SimplifyAbstractInterpCalculationsUtils.cpp +++ b/lib/Dialect/Torch/Transforms/SimplifyAbstractInterpCalculationsUtils.cpp @@ -55,8 +55,9 @@ class FullyUnrollPrimLoopOp : public OpRewritePattern { SmallVector indices; for (int64_t i = 0; i < maxTripCount; i++) { // TODO: Add convenience builder. - indices.push_back(rewriter.create( - loc, rewriter.getIntegerAttr(IntegerType::get(context, 64), i))); + indices.push_back(ConstantIntOp::create( + rewriter, loc, + rewriter.getIntegerAttr(IntegerType::get(context, 64), i))); } Block *beforeBlock = op->getBlock(); Block *afterBlock = rewriter.splitBlock(op->getBlock(), op->getIterator()); @@ -211,30 +212,33 @@ class AbstractlyInterpretListOpsWithinABlock return rewriter.notifyMatchFailure(op, "No new literal created"); // Rewrite all users to use the appropriate list literals. 
- Value latestLiteral = rewriter.create( - op->getLoc(), op.getType(), op->getOperands()); + Value latestLiteral = PrimListConstructOp::create( + rewriter, op->getLoc(), op.getType(), op->getOperands()); int nextLiteral = 0; for (Operation *user : usersToInterpret) { if (auto append = dyn_cast(user)) { rewriter.setInsertionPoint(append); - latestLiteral = rewriter.create( - append->getLoc(), op.getType(), listLiterals[nextLiteral++]); + latestLiteral = PrimListConstructOp::create( + rewriter, append->getLoc(), op.getType(), + listLiterals[nextLiteral++]); if (append.getSelf() == op) rewriter.eraseOp(append); continue; } if (auto insert = dyn_cast(user)) { rewriter.setInsertionPoint(insert); - latestLiteral = rewriter.create( - insert->getLoc(), op.getType(), listLiterals[nextLiteral++]); + latestLiteral = PrimListConstructOp::create( + rewriter, insert->getLoc(), op.getType(), + listLiterals[nextLiteral++]); if (insert.getSelf() == op) rewriter.eraseOp(insert); continue; } if (auto setItem = dyn_cast(user)) { rewriter.setInsertionPoint(setItem); - latestLiteral = rewriter.create( - setItem->getLoc(), op.getType(), listLiterals[nextLiteral++]); + latestLiteral = PrimListConstructOp::create( + rewriter, setItem->getLoc(), op.getType(), + listLiterals[nextLiteral++]); if (setItem.getL() == op) rewriter.eraseOp(setItem); continue; @@ -295,11 +299,11 @@ LogicalResult Torch::updateCalculateOpResultTypes(Operation *calculateOp, if (!originalTypedValue) { rewriter.setInsertionPointAfter(calculateOp); if (isa(originalResultType)) { - originalTypedValue = rewriter.create( - loc, originalResultType, result); + originalTypedValue = TensorStaticInfoCastOp::create( + rewriter, loc, originalResultType, result); } else if (isa(originalResultType)) { originalTypedValue = - rewriter.create(loc, originalResultType, result); + DerefineOp::create(rewriter, loc, originalResultType, result); } else { return rewriter.notifyMatchFailure( calculateOp, "Unimplemented: Expected result type to 
" @@ -326,10 +330,10 @@ LogicalResult Torch::updateCalculateOpResultTypes(Operation *calculateOp, rewriter.setInsertionPoint(yieldValues); if (isa(updatedType)) { newYieldedValue = - rewriter.create(loc, updatedType, def); + TensorStaticInfoCastOp::create(rewriter, loc, updatedType, def); } else { newYieldedValue = - rewriter.create(loc, updatedType, def); + PrimUncheckedCastOp::create(rewriter, loc, updatedType, def); } } use.set(newYieldedValue); diff --git a/lib/Dialect/Torch/Transforms/SimplifyShapeCalculations.cpp b/lib/Dialect/Torch/Transforms/SimplifyShapeCalculations.cpp index b690be48662a..54a9fb07d72b 100644 --- a/lib/Dialect/Torch/Transforms/SimplifyShapeCalculations.cpp +++ b/lib/Dialect/Torch/Transforms/SimplifyShapeCalculations.cpp @@ -33,13 +33,14 @@ class DecomposeAtenSizeOp : public OpRewritePattern { int64_t rank = tensorType.getSizes().size(); SmallVector sizes; for (int i = 0; i < rank; i++) { - Value dim = rewriter.create( - loc, rewriter.getI64IntegerAttr(i)); - sizes.push_back(rewriter.create(loc, self, dim)); + Value dim = Torch::ConstantIntOp::create(rewriter, loc, + rewriter.getI64IntegerAttr(i)); + sizes.push_back(AtenSizeIntOp::create(rewriter, loc, self, dim)); } - Value sizeList = rewriter.create( - loc, Torch::ListType::get(Torch::IntType::get(context)), sizes); + Value sizeList = PrimListConstructOp::create( + rewriter, loc, Torch::ListType::get(Torch::IntType::get(context)), + sizes); rewriter.replaceOp(op, sizeList); return success(); } @@ -90,7 +91,7 @@ class InferTensorOp : public OpRewritePattern { if (!originalTypedValue) { rewriter.setInsertionPointAfter(op); originalTypedValue = - rewriter.create(loc, resultType, result); + TensorStaticInfoCastOp::create(rewriter, loc, resultType, result); } use.set(originalTypedValue); } diff --git a/lib/Dialect/Torch/Utils/Utils.cpp b/lib/Dialect/Torch/Utils/Utils.cpp index 66c7c2ef2ec6..c95837c44261 100644 --- a/lib/Dialect/Torch/Utils/Utils.cpp +++ b/lib/Dialect/Torch/Utils/Utils.cpp @@ 
-41,11 +41,11 @@ Value Torch::toIntListConstruct(PatternRewriter &rewriter, Location loc, ArrayRef cstInput) { SmallVector cstValues; for (int64_t i : cstInput) { - cstValues.push_back(rewriter.create( - loc, rewriter.getI64IntegerAttr(i))); + cstValues.push_back(Torch::ConstantIntOp::create( + rewriter, loc, rewriter.getI64IntegerAttr(i))); } - return rewriter.create( - loc, Torch::ListType::get(IntType::get(rewriter.getContext())), + return Torch::PrimListConstructOp::create( + rewriter, loc, Torch::ListType::get(IntType::get(rewriter.getContext())), cstValues); } @@ -233,8 +233,8 @@ Type Torch::getBuiltInTypeForTorchScalar(Type type) { Value Torch::getDtypeIntValueForType(PatternRewriter &rewriter, Location loc, Type dtype) { int intType = (int)getScalarTypeForType(dtype); - return rewriter.create(loc, - rewriter.getI64IntegerAttr(intType)); + return ConstantIntOp::create(rewriter, loc, + rewriter.getI64IntegerAttr(intType)); } template @@ -281,10 +281,11 @@ Value Torch::convertTensorToDtype(PatternRewriter &rewriter, Location loc, // `convertIntVal` contains the corresponding integer for the dtype which is // used by the aten.to.dtype op. Value convertIntVal = getDtypeIntValueForType(rewriter, loc, dtype); - Value falseVal = rewriter.create(loc, false); - Value noneVal = rewriter.create(loc); - Value converted = rewriter.create( - loc, newType, input, convertIntVal, falseVal, falseVal, noneVal); + Value falseVal = ConstantBoolOp::create(rewriter, loc, false); + Value noneVal = ConstantNoneOp::create(rewriter, loc); + Value converted = + AtenToDtypeOp::create(rewriter, loc, newType, input, convertIntVal, + falseVal, falseVal, noneVal); return converted; } @@ -337,13 +338,13 @@ Value Torch::getConstantWithGivenDtypeAndValue(PatternRewriter &rewriter, // Creating constants satisfying backend contract. 
if (dtype.isInteger(64) || dtype.isInteger(32) || dtype.isInteger(16) || dtype.isInteger(8) || dtype.isInteger(1)) - return rewriter.create( - loc, rewriter.getI64IntegerAttr((int64_t)value)); + return ConstantIntOp::create(rewriter, loc, + rewriter.getI64IntegerAttr((int64_t)value)); if (dtype.isF64() || dtype.isF32() || dtype.isF16() || dtype.isBF16() || isa(dtype)) - return rewriter.create(loc, - rewriter.getF64FloatAttr(value)); + return ConstantFloatOp::create(rewriter, loc, + rewriter.getF64FloatAttr(value)); llvm::report_fatal_error( "unhandled type for getConstantWithGivenDtypeAndValue"); } @@ -402,7 +403,7 @@ Value Torch::getTensorDimSize(PatternRewriter &rewriter, Value tensor, int64_t dim) { auto loc = tensor.getLoc(); auto dimVal = - rewriter.create(loc, rewriter.getI64IntegerAttr(dim)); + ConstantIntOp::create(rewriter, loc, rewriter.getI64IntegerAttr(dim)); // Use 'createOrFold' instead of 'create': // If the dimension is a constant, then the AtenSizeIntOp is folded to a // ContantIntOp. @@ -428,19 +429,19 @@ FailureOr Torch::squeezeTensor(PatternRewriter &rewriter, Operation *op, Type squeezedType = inputType.getWithSizesAndDtype(inputShape, inputType.getOptionalDtype()); - Value cstDim = rewriter.create( - loc, rewriter.getI64IntegerAttr(dim)); + Value cstDim = Torch::ConstantIntOp::create(rewriter, loc, + rewriter.getI64IntegerAttr(dim)); // Adding a check to verify if the dimension to be squeezed has size 1 or not. 
- Value cstOne = - rewriter.create(loc, rewriter.getI64IntegerAttr(1)); - Value dimSize = rewriter.create(loc, input, cstDim); - Value cmp = rewriter.create(loc, dimSize, cstOne); - rewriter.create( - loc, cmp, + Value cstOne = Torch::ConstantIntOp::create(rewriter, loc, + rewriter.getI64IntegerAttr(1)); + Value dimSize = AtenSizeIntOp::create(rewriter, loc, input, cstDim); + Value cmp = Torch::AtenEqIntOp::create(rewriter, loc, dimSize, cstOne); + Torch::RuntimeAssertOp::create( + rewriter, loc, cmp, "squeeze operation possible for dim only when input_shape[dim] == 1."); Value result = - rewriter.create(loc, squeezedType, input, cstDim); + AtenSqueezeDimOp::create(rewriter, loc, squeezedType, input, cstDim); return result; } @@ -475,8 +476,8 @@ FailureOr Torch::unsqueezeTensor(PatternRewriter &rewriter, } Type unsqueezedType = inputType.getWithSizesAndDtypeAndSparsity( unsqueezedShape, inputType.getOptionalDtype(), enc.value()); - Value unsqueezed = rewriter.create( - op->getLoc(), unsqueezedType, input, dim); + Value unsqueezed = AtenUnsqueezeOp::create(rewriter, op->getLoc(), + unsqueezedType, input, dim); return unsqueezed; } @@ -498,8 +499,8 @@ void Torch::computeBroadcastShape(PatternRewriter &rewriter, Location loc, ranks.push_back(shape.size()); } - Value torchCstOne = - rewriter.create(loc, rewriter.getI64IntegerAttr(1)); + Value torchCstOne = Torch::ConstantIntOp::create( + rewriter, loc, rewriter.getI64IntegerAttr(1)); auto maxRankItr = std::max_element(ranks.begin(), ranks.end()); unsigned maxRank = *maxRankItr; @@ -515,8 +516,8 @@ void Torch::computeBroadcastShape(PatternRewriter &rewriter, Location loc, for (auto [idx, input] : llvm::enumerate(inputs)) { int sizeDimIdx = ranks[idx] - i - 1; if (sizeDimIdx >= 0) { - auto sizeDim = rewriter.create( - loc, rewriter.getI64IntegerAttr(sizeDimIdx)); + auto sizeDim = Torch::ConstantIntOp::create( + rewriter, loc, rewriter.getI64IntegerAttr(sizeDimIdx)); sizeInputs.push_back( rewriter.createOrFold(loc, input, 
sizeDim)); } @@ -526,29 +527,30 @@ void Torch::computeBroadcastShape(PatternRewriter &rewriter, Location loc, // which is the maximum of dimension sizes across all inputs Value maxShapeVal = sizeInputs.front(); for (auto sizeInput : sizeInputs) { - maxShapeVal = rewriter.create(loc, maxShapeVal, sizeInput); + maxShapeVal = PrimMaxIntOp::create(rewriter, loc, maxShapeVal, sizeInput); } maxShapeValues.push_back(maxShapeVal); SmallVector predicates; for (auto sizeVal : sizeInputs) { Value cmpSizeEquals = - rewriter.create(loc, sizeVal, maxShapeVal); + Torch::AtenEqIntOp::create(rewriter, loc, sizeVal, maxShapeVal); Value cmpSizeEqualsOne = - rewriter.create(loc, sizeVal, torchCstOne); - Value anyBoolOpList = rewriter.create( - loc, Torch::ListType::get(cmpSizeEquals.getType()), + Torch::AtenEqIntOp::create(rewriter, loc, sizeVal, torchCstOne); + Value anyBoolOpList = PrimListConstructOp::create( + rewriter, loc, Torch::ListType::get(cmpSizeEquals.getType()), SmallVector{cmpSizeEquals, cmpSizeEqualsOne}); - Value cmp = rewriter.create(loc, anyBoolOpList); + Value cmp = Torch::AtenAnyBoolOp::create(rewriter, loc, anyBoolOpList); predicates.push_back(cmp); } if (!predicates.empty()) { - Value anyBoolOpList = rewriter.create( - loc, Torch::ListType::get(predicates.front().getType()), predicates); - Value cmp = rewriter.create(loc, anyBoolOpList); - rewriter.create( - loc, cmp, "tensors are not broadcast compatible"); + Value anyBoolOpList = PrimListConstructOp::create( + rewriter, loc, Torch::ListType::get(predicates.front().getType()), + predicates); + Value cmp = Torch::AtenAllBoolOp::create(rewriter, loc, anyBoolOpList); + Torch::RuntimeAssertOp::create(rewriter, loc, cmp, + "tensors are not broadcast compatible"); } } @@ -558,8 +560,8 @@ void Torch::computeBroadcastShape(PatternRewriter &rewriter, Location loc, Value shapeTensor = inputs[maxRankIdx]; for (unsigned i = 0; i < resultShape.size(); i++) { - Value sizeDim = rewriter.create( - loc, 
rewriter.getI64IntegerAttr(i)); + Value sizeDim = Torch::ConstantIntOp::create(rewriter, loc, + rewriter.getI64IntegerAttr(i)); resultShapeValue.push_back( rewriter.createOrFold(loc, shapeTensor, sizeDim)); } @@ -660,12 +662,12 @@ Value Torch::createInitTensor(PatternRewriter &rewriter, Location loc, BaseTensorType resultType, Value scalar, Value sizeList) { assert(resultType.hasDtype() && "result must have dtype"); - Value noneVal = rewriter.create(loc); + Value noneVal = ConstantNoneOp::create(rewriter, loc); Value dtype = getDtypeIntValueForType(rewriter, loc, resultType.getDtype()); - return rewriter.create(loc, resultType, sizeList, scalar, dtype, - /*layout=*/noneVal, - /*device=*/noneVal, - /*memory_format=*/noneVal); + return AtenFullOp::create(rewriter, loc, resultType, sizeList, scalar, dtype, + /*layout=*/noneVal, + /*device=*/noneVal, + /*memory_format=*/noneVal); } // Helper to create a rank 0 tensor filled with the given `scalar`. `scalar` @@ -676,8 +678,9 @@ Value Torch::createRank0Tensor(PatternRewriter &rewriter, Location loc, SmallVector sizes; BaseTensorType rank0TensorTy = cast( inputType.getWithSizesAndDtype(ArrayRef(sizes), inputType.getDtype())); - Value dimList = rewriter.create( - loc, Torch::ListType::get(Torch::IntType::get(inputType.getContext())), + Value dimList = PrimListConstructOp::create( + rewriter, loc, + Torch::ListType::get(Torch::IntType::get(inputType.getContext())), ValueRange{}); return createInitTensor(rewriter, loc, rank0TensorTy, scalar, dimList); } diff --git a/lib/Dialect/TorchConversion/IR/TorchConversionDialect.cpp b/lib/Dialect/TorchConversion/IR/TorchConversionDialect.cpp index 5d9122fd7bc6..183e8543e9fb 100644 --- a/lib/Dialect/TorchConversion/IR/TorchConversionDialect.cpp +++ b/lib/Dialect/TorchConversion/IR/TorchConversionDialect.cpp @@ -62,13 +62,14 @@ Operation *TorchConversionDialect::materializeConstant(OpBuilder &builder, Type type, Location loc) { if (auto integerType = dyn_cast(type)) - return 
builder.create(loc, cast(value)); + return Torch::ConstantIntOp::create(builder, loc, cast(value)); if (auto floatType = dyn_cast(type)) - return builder.create(loc, cast(value)); + return Torch::ConstantFloatOp::create(builder, loc, cast(value)); if (isa(type)) { - return builder.create(loc, cast(value)); + return Torch::ConstantBoolOp::create(builder, loc, + cast(value)); } return arith::ConstantOp::materialize(builder, value, type, loc); diff --git a/lib/Dialect/TorchConversion/Transforms/BackendTypeConversion.cpp b/lib/Dialect/TorchConversion/Transforms/BackendTypeConversion.cpp index b37f7b41e268..f7ffb14f0602 100644 --- a/lib/Dialect/TorchConversion/Transforms/BackendTypeConversion.cpp +++ b/lib/Dialect/TorchConversion/Transforms/BackendTypeConversion.cpp @@ -38,14 +38,14 @@ static void setupValueTensorToBuiltinTensorConversion( assert(inputs.size() == 1); if (!isa(inputs[0].getType())) return {}; - return builder.create(loc, type, inputs[0]); + return ToBuiltinTensorOp::create(builder, loc, type, inputs[0]); }); auto sourceMaterialization = [](OpBuilder &builder, Torch::ValueTensorType type, ValueRange inputs, Location loc) -> Value { assert(inputs.size() == 1); assert(isa(inputs[0].getType())); - return builder.create(loc, type, inputs[0]); + return FromBuiltinTensorOp::create(builder, loc, type, inputs[0]); }; typeConverter.addSourceMaterialization(sourceMaterialization); } @@ -64,13 +64,13 @@ static void setupTorchBoolToI1Conversion(ConversionTarget &target, return Value(); assert(inputs.size() == 1); assert(isa(inputs[0].getType())); - return builder.create(loc, inputs[0]).getResult(); + return ToI1Op::create(builder, loc, inputs[0]).getResult(); }); auto sourceMaterialization = [](OpBuilder &builder, Torch::BoolType type, ValueRange inputs, Location loc) -> Value { assert(inputs.size() == 1); assert(isa(inputs[0].getType())); - return builder.create(loc, inputs[0]); + return FromI1Op::create(builder, loc, inputs[0]); }; 
typeConverter.addSourceMaterialization(sourceMaterialization); } @@ -98,7 +98,7 @@ static void setupTorchIntToI64Conversion(ConversionTarget &target, ValueRange inputs, Location loc) -> Value { assert(inputs.size() == 1); assert(isa(inputs[0].getType())); - return builder.create(loc, inputs[0]); + return FromI64Op::create(builder, loc, inputs[0]); }; typeConverter.addSourceMaterialization(sourceMaterialization); } @@ -114,13 +114,13 @@ static void setupTorchFloatToF64Conversion(ConversionTarget &target, Location loc) -> Value { assert(inputs.size() == 1); assert(isa(inputs[0].getType())); - return builder.create(loc, inputs[0]).getResult(); + return ToF64Op::create(builder, loc, inputs[0]).getResult(); }); auto sourceMaterialization = [](OpBuilder &builder, Torch::FloatType type, ValueRange inputs, Location loc) -> Value { assert(inputs.size() == 1); assert(isa(inputs[0].getType())); - return builder.create(loc, inputs[0]); + return FromF64Op::create(builder, loc, inputs[0]); }; typeConverter.addSourceMaterialization(sourceMaterialization); } @@ -144,13 +144,13 @@ static void setupTorchGeneratorToI64Conversion(ConversionTarget &target, if (!isa(inputs[0].getType())) return Value(); assert(inputs.size() == 1); - return builder.create(loc, inputs[0]).getResult(); + return GeneratorToI64Op::create(builder, loc, inputs[0]).getResult(); }); auto sourceMaterialization = [](OpBuilder &builder, Torch::GeneratorType type, ValueRange inputs, Location loc) -> Value { assert(inputs.size() == 1); assert(isa(inputs[0].getType())); - return builder.create(loc, inputs[0]); + return I64ToGeneratorOp::create(builder, loc, inputs[0]); }; typeConverter.addSourceMaterialization(sourceMaterialization); } diff --git a/lib/Dialect/TorchConversion/Transforms/ConvertCustomQuantOp.cpp b/lib/Dialect/TorchConversion/Transforms/ConvertCustomQuantOp.cpp index 5c30889c45a8..a4c28c2c3160 100644 --- a/lib/Dialect/TorchConversion/Transforms/ConvertCustomQuantOp.cpp +++ 
b/lib/Dialect/TorchConversion/Transforms/ConvertCustomQuantOp.cpp @@ -106,8 +106,8 @@ class ConvertCustomQuantizedMatmulOp : public OpConversionPattern { RankedTensorType lhsExpandedType = RankedTensorType::get(lhsExpandedShape, elementType); SmallVector lhsReassociation = {{0}, {1}, {2, 3}}; - Value lhsExpanded = rewriter.create( - loc, lhsExpandedType, lhs, lhsReassociation); + Value lhsExpanded = tensor::ExpandShapeOp::create( + rewriter, loc, lhsExpandedType, lhs, lhsReassociation); // expand rhs std::vector rhsExpandedShape = {rhsShape[0], rhsReductDimSize / gs, @@ -115,23 +115,23 @@ class ConvertCustomQuantizedMatmulOp : public OpConversionPattern { RankedTensorType rhsExpandedType = RankedTensorType::get(rhsExpandedShape, rhsElementType); SmallVector rhsReassociation = {{0}, {1, 2}}; - Value rhsExpanded = rewriter.create( - loc, rhsExpandedType, rhsQuant, rhsReassociation); - Value cst0 = rewriter.create( - loc, FloatAttr::get(elementType, 0.0)); + Value rhsExpanded = tensor::ExpandShapeOp::create( + rewriter, loc, rhsExpandedType, rhsQuant, rhsReassociation); + Value cst0 = arith::ConstantOp::create(rewriter, loc, + FloatAttr::get(elementType, 0.0)); Value emptyDequant = - rewriter.create(loc, rhsExpandedShape, elementType); + tensor::EmptyOp::create(rewriter, loc, rhsExpandedShape, elementType); SmallVector dynDims; for (int i = 0; i < lhsType.getRank(); i++) { if (lhsType.isDynamicDim(i)) { - dynDims.push_back(rewriter.create(loc, lhs, i)); + dynDims.push_back(tensor::DimOp::create(rewriter, loc, lhs, i)); } } - Value empty = rewriter.create(loc, resultShape, - elementType, dynDims); + Value empty = tensor::EmptyOp::create(rewriter, loc, resultShape, + elementType, dynDims); Value output = - rewriter.create(loc, cst0, empty).getResult(0); + linalg::FillOp::create(rewriter, loc, cst0, empty).getResult(0); AffineExpr d0, d1, d2, d3, d4; bindDims(getContext(), d0, d1, d2, d3, d4); @@ -152,39 +152,36 @@ class ConvertCustomQuantizedMatmulOp : public 
OpConversionPattern { utils::IteratorType::reduction}; Value rhsDequant = - rewriter - .create( - loc, emptyDequant.getType(), - ValueRange{rhsExpanded, scales, zps}, emptyDequant, - /*indexingMaps=*/dqIndexingMaps, - /*iteratorTypes=*/dequantIteratorTypes, - [&](OpBuilder &b, Location loc, ValueRange args) { - Value w = args[0], scale = args[1], zeroPoint = args[2]; - Value extw = - b.create(loc, rewriter.getI32Type(), w); - Value fp_extw = b.create( - loc, rewriter.getF16Type(), extw); - Value shifted = - b.create(loc, fp_extw, zeroPoint); - Value dqw = b.create(loc, shifted, scale); - b.create(loc, dqw); - }) + linalg::GenericOp::create( + rewriter, loc, emptyDequant.getType(), + ValueRange{rhsExpanded, scales, zps}, emptyDequant, + /*indexingMaps=*/dqIndexingMaps, + /*iteratorTypes=*/dequantIteratorTypes, + [&](OpBuilder &b, Location loc, ValueRange args) { + Value w = args[0], scale = args[1], zeroPoint = args[2]; + Value extw = + arith::ExtUIOp::create(b, loc, rewriter.getI32Type(), w); + Value fp_extw = + arith::UIToFPOp::create(b, loc, rewriter.getF16Type(), extw); + Value shifted = arith::SubFOp::create(b, loc, fp_extw, zeroPoint); + Value dqw = arith::MulFOp::create(b, loc, shifted, scale); + linalg::YieldOp::create(b, loc, dqw); + }) .getResult(0); - Value matmulDequant = - rewriter - .create( - loc, output.getType(), ValueRange{lhsExpanded, rhsDequant}, - output, - /*indexingMaps=*/matIndexingMaps, - /*iteratorTypes=*/matmulIteratorTypes, - [&](OpBuilder &b, Location loc, ValueRange args) { - Value l = args[0], r = args[1], out = args[2]; - Value pd = b.create(loc, l, r); - Value ac = b.create(loc, pd, out); - b.create(loc, ac); - }) - .getResult(0); + Value matmulDequant = linalg::GenericOp::create( + rewriter, loc, output.getType(), + ValueRange{lhsExpanded, rhsDequant}, output, + /*indexingMaps=*/matIndexingMaps, + /*iteratorTypes=*/matmulIteratorTypes, + [&](OpBuilder &b, Location loc, ValueRange args) { + Value l = args[0], r = args[1], out = 
args[2]; + Value pd = arith::MulFOp::create(b, loc, l, r); + Value ac = + arith::AddFOp::create(b, loc, pd, out); + linalg::YieldOp::create(b, loc, ac); + }) + .getResult(0); rewriter.replaceOpWithNewOp(op, resultType, matmulDequant); return success(); diff --git a/lib/RefBackend/RefBackend.cpp b/lib/RefBackend/RefBackend.cpp index 642510f85633..89c7fb5df21a 100644 --- a/lib/RefBackend/RefBackend.cpp +++ b/lib/RefBackend/RefBackend.cpp @@ -120,8 +120,8 @@ static void replaceReturnWithCall(OpBuilder b, func::ReturnOp op, StringRef funcName, TypeRange retTypes, SmallVectorImpl &vals, SmallVectorImpl &toErase) { - b.create(op.getLoc(), funcName, TypeRange({}), vals); - b.create(op.getLoc()); + mlir::func::CallOp::create(b, op.getLoc(), funcName, TypeRange({}), vals); + mlir::func::ReturnOp::create(b, op.getLoc()); toErase.push_back(op); } @@ -155,7 +155,7 @@ static LogicalResult mungeFunction( "got ", type); } - auto cast = b.create(arg.getLoc(), type, arg); + auto cast = memref::CastOp::create(b, arg.getLoc(), type, arg); arg.replaceAllUsesExcept(cast, cast); arg.setType(getAbiTypeForMemRef(type)); newArgTypes.push_back(arg.getType()); @@ -176,8 +176,8 @@ static LogicalResult mungeFunction( retType = UnrankedMemRefType::get(elemType, 0); // Cast to unranked memref type before sending it as a function // argument. - retVal = b.create( - op.getLoc(), getAbiTypeForMemRef(types[en.index()]), retVal); + retVal = memref::CastOp::create( + b, op.getLoc(), getAbiTypeForMemRef(types[en.index()]), retVal); } retTypes.push_back(retType); retVals.push_back(retVal); @@ -210,8 +210,8 @@ class MungeCallingConventions // Create FuncOp for consumeFuncReturnFuncs that are used. 
for (auto &p : invokedConsumeFuncReturnFuncs) { - auto consumeFuncReturnFunc = b.create( - module.getLoc(), p.first, + auto consumeFuncReturnFunc = func::FuncOp::create( + b, module.getLoc(), p.first, FunctionType::get(module.getContext(), p.second, {})); consumeFuncReturnFunc.setPrivate(); addEmitCInterfaceAttr(consumeFuncReturnFunc); @@ -239,13 +239,13 @@ static LogicalResult bufferizeMLProgramGlobalOp(ml_program::GlobalOp globalOp, MemRefType::get(tensorType.getShape(), tensorType.getElementType()); b.setInsertionPointToStart(globalOp->getParentOfType().getBody()); - b.create( - UnknownLoc::get(b.getContext()), globalOp.getSymName(), - /*sym_visibility=*/globalOp.getSymVisibilityAttr(), - /*type=*/memrefType, - /*initial_value=*/globalOp.getValue().value(), - /*constant=*/globalOp.getIsMutable() ? false : true, - /*alignment=*/nullptr); + memref::GlobalOp::create(b, UnknownLoc::get(b.getContext()), + globalOp.getSymName(), + /*sym_visibility=*/globalOp.getSymVisibilityAttr(), + /*type=*/memrefType, + /*initial_value=*/globalOp.getValue().value(), + /*constant=*/globalOp.getIsMutable() ? 
false : true, + /*alignment=*/nullptr); return success(); } @@ -257,11 +257,11 @@ bufferizeMLProgramGlobaLoadOp(ml_program::GlobalLoadOp globalLoadOp, MemRefType::get(tensorType.getShape(), tensorType.getElementType()); b.setInsertionPoint(globalLoadOp); - Value globalVal = b.create( - globalLoadOp.getLoc(), memrefType, + Value globalVal = memref::GetGlobalOp::create( + b, globalLoadOp.getLoc(), memrefType, globalLoadOp.getGlobalAttr().getLeafReference()); - globalVal = b.create(globalLoadOp->getLoc(), - tensorType, globalVal); + globalVal = bufferization::ToTensorOp::create(b, globalLoadOp->getLoc(), + tensorType, globalVal); globalLoadOp->getResult(0).replaceAllUsesWith(globalVal); return success(); } @@ -276,12 +276,12 @@ bufferizeMLProgramGlobaStoreOp(ml_program::GlobalStoreOp globalStoreOp, MemRefType::get(tensorType.getShape(), tensorType.getElementType()); b.setInsertionPoint(globalStoreOp); - Value memref = b.create( - globalStoreOp.getLoc(), memrefType, + Value memref = memref::GetGlobalOp::create( + b, globalStoreOp.getLoc(), memrefType, globalStoreOp.getGlobalAttr().getLeafReference()); - Value copyValue = b.create( - globalStoreOp->getLoc(), memrefType, globalStoreOp.getValue()); - b.create(globalStoreOp->getLoc(), copyValue, memref); + Value copyValue = bufferization::ToBufferOp::create( + b, globalStoreOp->getLoc(), memrefType, globalStoreOp.getValue()); + memref::CopyOp::create(b, globalStoreOp->getLoc(), copyValue, memref); return success(); } @@ -396,14 +396,14 @@ Operation *createLinalgCopyOp(OpBuilder &b, Location loc, Value from, AffineMap::getMultiDimIdentityMap(memrefTypeTo.getRank(), b.getContext()); SmallVector iteratorTypes(memrefTypeTo.getRank(), utils::IteratorType::parallel); - return b.create( - loc, + return linalg::GenericOp::create( + b, loc, /*inputs=*/from, /*outputs=*/to, /*indexingMaps=*/llvm::ArrayRef({id, id}), /*iteratorTypes=*/iteratorTypes, [](OpBuilder &b, Location loc, ValueRange args) { - b.create(loc, args.front()); + 
linalg::YieldOp::create(b, loc, args.front()); }); }