2 changes: 1 addition & 1 deletion .gitmodules
@@ -1,6 +1,6 @@
 [submodule "externals/llvm-project"]
 	path = externals/llvm-project
-	url = https://github.com/iree-org/llvm-project.git
+	url = https://github.com/llvm/llvm-project.git
 [submodule "externals/stablehlo"]
 	path = externals/stablehlo
 	url = https://github.com/openxla/stablehlo.git
4 changes: 2 additions & 2 deletions include/torch-mlir/Conversion/TorchOnnxToTorch/Utils.h
@@ -44,8 +44,8 @@ Torch::ValueTensorType getQTorchTypeFromTorchIntType(Type ty);
 template <typename T>
 Value getItemOp(OpBinder binder, ConversionPatternRewriter &rewriter,
                 Value &ofItem) {
-  return rewriter.create<Torch::AtenItemOp>(binder.getLoc(),
-                                            rewriter.getType<T>(), ofItem);
+  return Torch::AtenItemOp::create(rewriter, binder.getLoc(),
+                                   rewriter.getType<T>(), ofItem);
 }

 LogicalResult OnnxLstmExpander(OpBinder binder,
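The change throughout this PR is mechanical: every rewriter.create<OpTy>(loc, ...) becomes OpTy::create(rewriter, loc, ...), which appears to track upstream MLIR's move away from the OpBuilder::create<OpTy> member template toward static per-op create functions. A minimal before/after sketch of the rename, based on the getItemOp helper above (the "item" variable is illustrative, not part of the patch):

// Before: op construction through the OpBuilder member template.
Value item = rewriter.create<Torch::AtenItemOp>(
    binder.getLoc(), rewriter.getType<T>(), ofItem);

// After: static create on the op class; the builder moves to the front of
// the argument list and all other arguments keep their order.
Value item = Torch::AtenItemOp::create(
    rewriter, binder.getLoc(), rewriter.getType<T>(), ofItem);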
@@ -40,8 +40,8 @@ static LogicalResult getOrCreateGlobalVariableForSeed(OpBuilder &b,
   }

   b.setInsertionPointToStart(module.getBody());
-  b.create<ml_program::GlobalOp>(
-      UnknownLoc::get(b.getContext()),
+  ml_program::GlobalOp::create(
+      b, UnknownLoc::get(b.getContext()),
       /*sym_name=*/getSeedGobalVarName(),
       /*type=*/tensorType,
       /*is_mutable=*/true,
@@ -71,25 +71,25 @@ class ConvertGetNextSeedOp : public OpConversionPattern<GetNextSeedOp> {
     // Refer to https://en.wikipedia.org/wiki/Linear_congruential_generator.
     // Get the current seed value.
     auto tensorType = RankedTensorType::get({}, rewriter.getI64Type());
-    Value globalVar = rewriter.create<ml_program::GlobalLoadOp>(
-        loc, tensorType,
+    Value globalVar = ml_program::GlobalLoadOp::create(
+        rewriter, loc, tensorType,
         SymbolRefAttr::get(op->getContext(), getSeedGobalVarName()));
-    Value currentSeed = rewriter.create<tensor::ExtractOp>(loc, globalVar);
+    Value currentSeed = tensor::ExtractOp::create(rewriter, loc, globalVar);

     // The value of multiplier and incrementStep are referenced from
     // https://en.wikipedia.org/wiki/Linear_congruential_generator for 2^64.
-    Value multiplier = rewriter.create<arith::ConstantOp>(
-        loc, rewriter.getI64IntegerAttr(6364136223846793005));
-    Value incrementStep = rewriter.create<arith::ConstantOp>(
-        loc, rewriter.getI64IntegerAttr(1442695040888963407));
+    Value multiplier = arith::ConstantOp::create(
+        rewriter, loc, rewriter.getI64IntegerAttr(6364136223846793005));
+    Value incrementStep = arith::ConstantOp::create(
+        rewriter, loc, rewriter.getI64IntegerAttr(1442695040888963407));
     // temp = multiplier * currentSeed + incrementStep
-    Value mul = rewriter.create<arith::MulIOp>(loc, currentSeed, multiplier);
-    Value seed = rewriter.create<arith::AddIOp>(loc, mul, incrementStep);
+    Value mul = arith::MulIOp::create(rewriter, loc, currentSeed, multiplier);
+    Value seed = arith::AddIOp::create(rewriter, loc, mul, incrementStep);
     globalVar =
-        rewriter.create<tensor::InsertOp>(loc, seed, globalVar, ValueRange());
-    rewriter.create<ml_program::GlobalStoreOp>(
-        loc, SymbolRefAttr::get(op->getContext(), getSeedGobalVarName()),
-        globalVar);
+        tensor::InsertOp::create(rewriter, loc, seed, globalVar, ValueRange());
+    ml_program::GlobalStoreOp::create(
+        rewriter, loc,
+        SymbolRefAttr::get(op->getContext(), getSeedGobalVarName()), globalVar);
     rewriter.replaceOp(op, seed);
     return success();
   }
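For reference, the seed update in ConvertGetNextSeedOp is a linear congruential generator using the mod-2^64 constants from the linked Wikipedia table (Knuth's MMIX values). A standalone sketch of the arithmetic, independent of the lowering (nextSeed and the constant names are invented for illustration):

#include <cstdint>

// next = seed * multiplier + increment (mod 2^64). Unsigned 64-bit
// wraparound supplies the modulus; the lowered IR gets the same behavior
// from the wrapping semantics of i64 arith.muli / arith.addi.
uint64_t nextSeed(uint64_t seed) {
  constexpr uint64_t kMultiplier = 6364136223846793005ULL;
  constexpr uint64_t kIncrement = 1442695040888963407ULL;
  return seed * kMultiplier + kIncrement;
}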
435 changes: 218 additions & 217 deletions lib/Conversion/TorchOnnxToTorch/ComMicrosoftDomain.cpp

Large diffs are not rendered by default.

1,328 changes: 683 additions & 645 deletions lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp

Large diffs are not rendered by default.

1,819 changes: 927 additions & 892 deletions lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp

Large diffs are not rendered by default.

2,094 changes: 1,086 additions & 1,008 deletions lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp

Large diffs are not rendered by default.

499 changes: 252 additions & 247 deletions lib/Conversion/TorchOnnxToTorch/OnnxRecurrentLayerOpExpanders.cpp

Large diffs are not rendered by default.

47 changes: 23 additions & 24 deletions lib/Conversion/TorchOnnxToTorch/Utils.cpp
@@ -20,11 +20,11 @@ Value mlir::torch::onnx_c::createConstantIntList(
     ArrayRef<int64_t> cstInput) {
   SmallVector<Value> cstValue;
   for (int64_t i : cstInput) {
-    cstValue.push_back(rewriter.create<Torch::ConstantIntOp>(
-        binder.getLoc(), rewriter.getI64IntegerAttr(i)));
+    cstValue.push_back(Torch::ConstantIntOp::create(
+        rewriter, binder.getLoc(), rewriter.getI64IntegerAttr(i)));
   }
-  return rewriter.create<Torch::PrimListConstructOp>(
-      binder.getLoc(),
+  return Torch::PrimListConstructOp::create(
+      rewriter, binder.getLoc(),
       Torch::ListType::get(Torch::IntType::get(binder.op->getContext())),
       cstValue);
 }
@@ -109,12 +109,12 @@ LogicalResult mlir::torch::onnx_c::createTorchTransposeOp(
   if (failed(getTransposedType(cast<Torch::BaseTensorType>(input.getType()),
                                dimA, dimB, transposedType)))
     return failure();
-  Value cstDimA = rewriter.create<Torch::ConstantIntOp>(
-      loc, rewriter.getI64IntegerAttr(dimA));
-  Value cstDimB = rewriter.create<Torch::ConstantIntOp>(
-      loc, rewriter.getI64IntegerAttr(dimB));
-  transposed = rewriter.create<Torch::AtenTransposeIntOp>(
-      loc, transposedType, input, cstDimA, cstDimB);
+  Value cstDimA = Torch::ConstantIntOp::create(
+      rewriter, loc, rewriter.getI64IntegerAttr(dimA));
+  Value cstDimB = Torch::ConstantIntOp::create(
+      rewriter, loc, rewriter.getI64IntegerAttr(dimB));
+  transposed = Torch::AtenTransposeIntOp::create(rewriter, loc, transposedType,
+                                                 input, cstDimA, cstDimB);
   return success();
 }

@@ -127,19 +127,19 @@ LogicalResult mlir::torch::onnx_c::createTorchPermuteOp(
                               permuteDims, permutedType)))
     return failure();
   Value permuteDimsList = createConstantIntList(binder, rewriter, permuteDims);
-  permuted = rewriter.create<Torch::AtenPermuteOp>(loc, permutedType, input,
-                                                   permuteDimsList);
+  permuted = Torch::AtenPermuteOp::create(rewriter, loc, permutedType, input,
+                                          permuteDimsList);
   return success();
 }

 Value mlir::torch::onnx_c::createActivationByName(ImplicitLocOpBuilder &b,
                                                   StringRef name, Value input) {
   if (name == "Sigmoid")
-    return b.create<Torch::AtenSigmoidOp>(input.getType(), input);
+    return Torch::AtenSigmoidOp::create(b, input.getType(), input);
   if (name == "Tanh")
-    return b.create<Torch::AtenTanhOp>(input.getType(), input);
+    return Torch::AtenTanhOp::create(b, input.getType(), input);
   if (name == "Relu")
-    return b.create<Torch::AtenReluOp>(input.getType(), input);
+    return Torch::AtenReluOp::create(b, input.getType(), input);
   llvm_unreachable("Unsupported activation function");
 }
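One wrinkle worth noting in createActivationByName: the builder here is an ImplicitLocOpBuilder, which carries a Location, so neither the old nor the new spelling passes one explicitly. A side-by-side sketch of the two builder flavors (loc, input, and rewriter are placeholders):

// Plain rewriter / OpBuilder: the location is an explicit argument.
Value a = Torch::AtenReluOp::create(rewriter, loc, input.getType(), input);

// ImplicitLocOpBuilder: the location stored in b is used implicitly, so
// the migration only adds the builder argument.
ImplicitLocOpBuilder b(loc, rewriter);
Value r = Torch::AtenReluOp::create(b, input.getType(), input);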

@@ -158,23 +158,23 @@ LogicalResult mlir::torch::onnx_c::extractPerTensorQuantizationArguments(
   if (!check(inScale) || !check(inZeroPoint))
     return failure();

-  Value emptyList = rewriter.create<Torch::PrimListConstructOp>(
-      loc,
+  Value emptyList = Torch::PrimListConstructOp::create(
+      rewriter, loc,
       rewriter.getType<Torch::ListType>(rewriter.getType<Torch::IntType>()),
       ValueRange{});
   auto extract = [&rewriter, &loc, &emptyList](Value v) {
     auto vTy = cast<Torch::ValueTensorType>(v.getType());
     if (!vTy.getSizes().empty()) {
       vTy = rewriter.getType<Torch::ValueTensorType>(ArrayRef<int64_t>({}),
                                                      vTy.getOptionalDtype());
-      v = rewriter.create<Torch::AtenReshapeOp>(loc, vTy, v, emptyList);
+      v = Torch::AtenReshapeOp::create(rewriter, loc, vTy, v, emptyList);
     }

     Type extractTy = rewriter.getType<Torch::FloatType>();
     if (isa<IntegerType>(vTy.getDtype()))
       extractTy = rewriter.getType<Torch::IntType>();

-    return rewriter.create<Torch::AtenItemOp>(loc, extractTy, v);
+    return Torch::AtenItemOp::create(rewriter, loc, extractTy, v);
   };

   outScale = extract(inScale);
@@ -191,14 +191,13 @@ LogicalResult mlir::torch::onnx_c::createDequantizeTensor(
     return failure();

   Torch::ValueTensorType makeTensorTy = getQTorchTypeFromTorchIntType(inputTy);
-  Value quantizedInput =
-      rewriter.create<Torch::Aten_MakePerTensorQuantizedTensorOp>(
-          loc, makeTensorTy, input, scale, zeroPoint);
+  Value quantizedInput = Torch::Aten_MakePerTensorQuantizedTensorOp::create(
+      rewriter, loc, makeTensorTy, input, scale, zeroPoint);

   Torch::ValueTensorType resultTy = rewriter.getType<Torch::ValueTensorType>(
       inputTy.getSizes(), rewriter.getF32Type());
-  output = rewriter.create<Torch::AtenDequantizeSelfOp>(loc, resultTy,
-                                                        quantizedInput);
+  output = Torch::AtenDequantizeSelfOp::create(rewriter, loc, resultTy,
+                                               quantizedInput);

   return success();
 }
20 changes: 10 additions & 10 deletions lib/Conversion/TorchToArith/TorchToArith.cpp
@@ -41,7 +41,7 @@ class ConvertAtenDimOp : public OpConversionPattern<AtenDimOp> {
   matchAndRewrite(AtenDimOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     auto rank =
-        rewriter.create<tensor::RankOp>(op->getLoc(), adaptor.getSelf());
+        tensor::RankOp::create(rewriter, op->getLoc(), adaptor.getSelf());
     rewriter.replaceOpWithNewOp<arith::IndexCastOp>(
         op, getTypeConverter()->convertType(op.getType()), rank);
     return success();
@@ -96,8 +96,8 @@ class ConvertAtenNegIntOp : public OpConversionPattern<AtenNegIntOp> {
     Value a = adaptor.getA();
     rewriter.replaceOpWithNewOp<arith::SubIOp>(
         op,
-        rewriter.create<arith::ConstantIntOp>(op.getLoc(), /*value=*/0,
-                                              /*bitwidth=*/64),
+        arith::ConstantIntOp::create(rewriter, op.getLoc(), /*value=*/0,
+                                     /*bitwidth=*/64),
         a);
     return success();
   }
@@ -119,7 +119,7 @@ class ConvertAtenUnaryOpToFloatMathOp : public OpConversionPattern<AtenOp> {
         this->getTypeConverter()->convertType(op->getResult(0).getType());
     if (!isa<mlir::FloatType>(input.getType()))
       input = convertScalarToDtype(rewriter, loc, input, rewriter.getF64Type());
-    Value result = rewriter.create<UnaryOp>(loc, input);
+    Value result = UnaryOp::create(rewriter, loc, input);
     rewriter.replaceOp(op,
                        convertScalarToDtype(rewriter, loc, result, resultType));
     return success();
@@ -347,7 +347,7 @@ class ConvertAtenAnyOrAllBoolOp : public OpConversionPattern<OpTy> {
         rewriter, loc, this->getTypeConverter(), inputListTorchBool);
     result = inputList[0];
     for (unsigned i = 1; i < inputList.size(); i++)
-      result = rewriter.create<BinOp>(loc, result, inputList[i]);
+      result = BinOp::create(rewriter, loc, result, inputList[i]);
     rewriter.replaceOp(op, result);
     return success();
   }
@@ -385,15 +385,15 @@ class ConvertAtenBoolLikeOp : public OpConversionPattern<OpTy> {
                   ConversionPatternRewriter &rewriter) const override {
     Location loc = op.getLoc();
     Type inputType = adaptor.getA().getType();
-    Value cstZero = rewriter.create<arith::ConstantOp>(
-        loc, rewriter.getZeroAttr(inputType));
+    Value cstZero = arith::ConstantOp::create(rewriter, loc,
+                                              rewriter.getZeroAttr(inputType));
     Value cstTrue =
-        rewriter.create<arith::ConstantOp>(loc, rewriter.getBoolAttr(true));
+        arith::ConstantOp::create(rewriter, loc, rewriter.getBoolAttr(true));
     Value cstFalse =
-        rewriter.create<arith::ConstantOp>(loc, rewriter.getBoolAttr(false));
+        arith::ConstantOp::create(rewriter, loc, rewriter.getBoolAttr(false));

     Value cmpPred;
-    cmpPred = rewriter.create<CmpOpTy>(loc, Pred, adaptor.getA(), cstZero);
+    cmpPred = CmpOpTy::create(rewriter, loc, Pred, adaptor.getA(), cstZero);
     rewriter.replaceOpWithNewOp<arith::SelectOp>(op, cmpPred, cstTrue,
                                                  cstFalse);
     return success();
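Also visible in this file: only bare create calls are renamed. The rewriter.replaceOpWithNewOp<...> helper, which creates the new op and replaces the matched op in one step, keeps its member-function spelling, as in ConvertAtenDimOp above (sketch copied from that pattern for contrast):

// Renamed by this PR: standalone op creation.
auto rank = tensor::RankOp::create(rewriter, op->getLoc(), adaptor.getSelf());

// Untouched by this PR: replaceOpWithNewOp is a PatternRewriter API, not a
// per-op builder, so it falls outside the scope of the rename.
rewriter.replaceOpWithNewOp<arith::IndexCastOp>(
    op, getTypeConverter()->convertType(op.getType()), rank);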