Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 36 additions & 36 deletions mlir/lib/Conversion/IndexToLLVM/IndexToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,34 +36,34 @@ struct ConvertIndexCeilDivS : mlir::ConvertOpToLLVMPattern<CeilDivSOp> {
Location loc = op.getLoc();
Value n = adaptor.getLhs();
Value m = adaptor.getRhs();
Value zero = rewriter.create<LLVM::ConstantOp>(loc, n.getType(), 0);
Value posOne = rewriter.create<LLVM::ConstantOp>(loc, n.getType(), 1);
Value negOne = rewriter.create<LLVM::ConstantOp>(loc, n.getType(), -1);
Value zero = LLVM::ConstantOp::create(rewriter, loc, n.getType(), 0);
Value posOne = LLVM::ConstantOp::create(rewriter, loc, n.getType(), 1);
Value negOne = LLVM::ConstantOp::create(rewriter, loc, n.getType(), -1);

// Compute `x`.
Value mPos =
rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::sgt, m, zero);
Value x = rewriter.create<LLVM::SelectOp>(loc, mPos, negOne, posOne);
LLVM::ICmpOp::create(rewriter, loc, LLVM::ICmpPredicate::sgt, m, zero);
Value x = LLVM::SelectOp::create(rewriter, loc, mPos, negOne, posOne);

// Compute the positive result.
Value nPlusX = rewriter.create<LLVM::AddOp>(loc, n, x);
Value nPlusXDivM = rewriter.create<LLVM::SDivOp>(loc, nPlusX, m);
Value posRes = rewriter.create<LLVM::AddOp>(loc, nPlusXDivM, posOne);
Value nPlusX = LLVM::AddOp::create(rewriter, loc, n, x);
Value nPlusXDivM = LLVM::SDivOp::create(rewriter, loc, nPlusX, m);
Value posRes = LLVM::AddOp::create(rewriter, loc, nPlusXDivM, posOne);

// Compute the negative result.
Value negN = rewriter.create<LLVM::SubOp>(loc, zero, n);
Value negNDivM = rewriter.create<LLVM::SDivOp>(loc, negN, m);
Value negRes = rewriter.create<LLVM::SubOp>(loc, zero, negNDivM);
Value negN = LLVM::SubOp::create(rewriter, loc, zero, n);
Value negNDivM = LLVM::SDivOp::create(rewriter, loc, negN, m);
Value negRes = LLVM::SubOp::create(rewriter, loc, zero, negNDivM);

// Pick the positive result if `n` and `m` have the same sign and `n` is
// non-zero, i.e. `(n > 0) == (m > 0) && n != 0`.
Value nPos =
rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::sgt, n, zero);
Value sameSign =
rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::eq, nPos, mPos);
LLVM::ICmpOp::create(rewriter, loc, LLVM::ICmpPredicate::sgt, n, zero);
Value sameSign = LLVM::ICmpOp::create(rewriter, loc,
LLVM::ICmpPredicate::eq, nPos, mPos);
Value nNonZero =
rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::ne, n, zero);
Value cmp = rewriter.create<LLVM::AndOp>(loc, sameSign, nNonZero);
LLVM::ICmpOp::create(rewriter, loc, LLVM::ICmpPredicate::ne, n, zero);
Value cmp = LLVM::AndOp::create(rewriter, loc, sameSign, nNonZero);
rewriter.replaceOpWithNewOp<LLVM::SelectOp>(op, cmp, posRes, negRes);
return success();
}
Expand All @@ -83,17 +83,17 @@ struct ConvertIndexCeilDivU : mlir::ConvertOpToLLVMPattern<CeilDivUOp> {
Location loc = op.getLoc();
Value n = adaptor.getLhs();
Value m = adaptor.getRhs();
Value zero = rewriter.create<LLVM::ConstantOp>(loc, n.getType(), 0);
Value one = rewriter.create<LLVM::ConstantOp>(loc, n.getType(), 1);
Value zero = LLVM::ConstantOp::create(rewriter, loc, n.getType(), 0);
Value one = LLVM::ConstantOp::create(rewriter, loc, n.getType(), 1);

// Compute the non-zero result.
Value minusOne = rewriter.create<LLVM::SubOp>(loc, n, one);
Value quotient = rewriter.create<LLVM::UDivOp>(loc, minusOne, m);
Value plusOne = rewriter.create<LLVM::AddOp>(loc, quotient, one);
Value minusOne = LLVM::SubOp::create(rewriter, loc, n, one);
Value quotient = LLVM::UDivOp::create(rewriter, loc, minusOne, m);
Value plusOne = LLVM::AddOp::create(rewriter, loc, quotient, one);

// Pick the result.
Value cmp =
rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::eq, n, zero);
LLVM::ICmpOp::create(rewriter, loc, LLVM::ICmpPredicate::eq, n, zero);
rewriter.replaceOpWithNewOp<LLVM::SelectOp>(op, cmp, zero, plusOne);
return success();
}
Expand All @@ -114,32 +114,32 @@ struct ConvertIndexFloorDivS : mlir::ConvertOpToLLVMPattern<FloorDivSOp> {
Location loc = op.getLoc();
Value n = adaptor.getLhs();
Value m = adaptor.getRhs();
Value zero = rewriter.create<LLVM::ConstantOp>(loc, n.getType(), 0);
Value posOne = rewriter.create<LLVM::ConstantOp>(loc, n.getType(), 1);
Value negOne = rewriter.create<LLVM::ConstantOp>(loc, n.getType(), -1);
Value zero = LLVM::ConstantOp::create(rewriter, loc, n.getType(), 0);
Value posOne = LLVM::ConstantOp::create(rewriter, loc, n.getType(), 1);
Value negOne = LLVM::ConstantOp::create(rewriter, loc, n.getType(), -1);

// Compute `x`.
Value mNeg =
rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::slt, m, zero);
Value x = rewriter.create<LLVM::SelectOp>(loc, mNeg, posOne, negOne);
LLVM::ICmpOp::create(rewriter, loc, LLVM::ICmpPredicate::slt, m, zero);
Value x = LLVM::SelectOp::create(rewriter, loc, mNeg, posOne, negOne);

// Compute the negative result.
Value xMinusN = rewriter.create<LLVM::SubOp>(loc, x, n);
Value xMinusNDivM = rewriter.create<LLVM::SDivOp>(loc, xMinusN, m);
Value negRes = rewriter.create<LLVM::SubOp>(loc, negOne, xMinusNDivM);
Value xMinusN = LLVM::SubOp::create(rewriter, loc, x, n);
Value xMinusNDivM = LLVM::SDivOp::create(rewriter, loc, xMinusN, m);
Value negRes = LLVM::SubOp::create(rewriter, loc, negOne, xMinusNDivM);

// Compute the positive result.
Value posRes = rewriter.create<LLVM::SDivOp>(loc, n, m);
Value posRes = LLVM::SDivOp::create(rewriter, loc, n, m);

// Pick the negative result if `n` and `m` have different signs and `n` is
// non-zero, i.e. `(n < 0) != (m < 0) && n != 0`.
Value nNeg =
rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::slt, n, zero);
Value diffSign =
rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::ne, nNeg, mNeg);
LLVM::ICmpOp::create(rewriter, loc, LLVM::ICmpPredicate::slt, n, zero);
Value diffSign = LLVM::ICmpOp::create(rewriter, loc,
LLVM::ICmpPredicate::ne, nNeg, mNeg);
Value nNonZero =
rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::ne, n, zero);
Value cmp = rewriter.create<LLVM::AndOp>(loc, diffSign, nNonZero);
LLVM::ICmpOp::create(rewriter, loc, LLVM::ICmpPredicate::ne, n, zero);
Value cmp = LLVM::AndOp::create(rewriter, loc, diffSign, nNonZero);
rewriter.replaceOpWithNewOp<LLVM::SelectOp>(op, cmp, negRes, posRes);
return success();
}
Expand Down
85 changes: 43 additions & 42 deletions mlir/lib/Conversion/IndexToSPIRV/IndexToSPIRV.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -111,33 +111,33 @@ struct ConvertIndexCeilDivSPattern final : OpConversionPattern<CeilDivSOp> {
Value m = adaptor.getRhs();

// Define the constants
Value zero = rewriter.create<spirv::ConstantOp>(
loc, n_type, IntegerAttr::get(n_type, 0));
Value posOne = rewriter.create<spirv::ConstantOp>(
loc, n_type, IntegerAttr::get(n_type, 1));
Value negOne = rewriter.create<spirv::ConstantOp>(
loc, n_type, IntegerAttr::get(n_type, -1));
Value zero = spirv::ConstantOp::create(rewriter, loc, n_type,
IntegerAttr::get(n_type, 0));
Value posOne = spirv::ConstantOp::create(rewriter, loc, n_type,
IntegerAttr::get(n_type, 1));
Value negOne = spirv::ConstantOp::create(rewriter, loc, n_type,
IntegerAttr::get(n_type, -1));

// Compute `x`.
Value mPos = rewriter.create<spirv::SGreaterThanOp>(loc, m, zero);
Value x = rewriter.create<spirv::SelectOp>(loc, mPos, negOne, posOne);
Value mPos = spirv::SGreaterThanOp::create(rewriter, loc, m, zero);
Value x = spirv::SelectOp::create(rewriter, loc, mPos, negOne, posOne);

// Compute the positive result.
Value nPlusX = rewriter.create<spirv::IAddOp>(loc, n, x);
Value nPlusXDivM = rewriter.create<spirv::SDivOp>(loc, nPlusX, m);
Value posRes = rewriter.create<spirv::IAddOp>(loc, nPlusXDivM, posOne);
Value nPlusX = spirv::IAddOp::create(rewriter, loc, n, x);
Value nPlusXDivM = spirv::SDivOp::create(rewriter, loc, nPlusX, m);
Value posRes = spirv::IAddOp::create(rewriter, loc, nPlusXDivM, posOne);

// Compute the negative result.
Value negN = rewriter.create<spirv::ISubOp>(loc, zero, n);
Value negNDivM = rewriter.create<spirv::SDivOp>(loc, negN, m);
Value negRes = rewriter.create<spirv::ISubOp>(loc, zero, negNDivM);
Value negN = spirv::ISubOp::create(rewriter, loc, zero, n);
Value negNDivM = spirv::SDivOp::create(rewriter, loc, negN, m);
Value negRes = spirv::ISubOp::create(rewriter, loc, zero, negNDivM);

// Pick the positive result if `n` and `m` have the same sign and `n` is
// non-zero, i.e. `(n > 0) == (m > 0) && n != 0`.
Value nPos = rewriter.create<spirv::SGreaterThanOp>(loc, n, zero);
Value sameSign = rewriter.create<spirv::LogicalEqualOp>(loc, nPos, mPos);
Value nNonZero = rewriter.create<spirv::INotEqualOp>(loc, n, zero);
Value cmp = rewriter.create<spirv::LogicalAndOp>(loc, sameSign, nNonZero);
Value nPos = spirv::SGreaterThanOp::create(rewriter, loc, n, zero);
Value sameSign = spirv::LogicalEqualOp::create(rewriter, loc, nPos, mPos);
Value nNonZero = spirv::INotEqualOp::create(rewriter, loc, n, zero);
Value cmp = spirv::LogicalAndOp::create(rewriter, loc, sameSign, nNonZero);
rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, cmp, posRes, negRes);
return success();
}
Expand All @@ -161,18 +161,18 @@ struct ConvertIndexCeilDivUPattern final : OpConversionPattern<CeilDivUOp> {
Value m = adaptor.getRhs();

// Define the constants
Value zero = rewriter.create<spirv::ConstantOp>(
loc, n_type, IntegerAttr::get(n_type, 0));
Value one = rewriter.create<spirv::ConstantOp>(loc, n_type,
IntegerAttr::get(n_type, 1));
Value zero = spirv::ConstantOp::create(rewriter, loc, n_type,
IntegerAttr::get(n_type, 0));
Value one = spirv::ConstantOp::create(rewriter, loc, n_type,
IntegerAttr::get(n_type, 1));

// Compute the non-zero result.
Value minusOne = rewriter.create<spirv::ISubOp>(loc, n, one);
Value quotient = rewriter.create<spirv::UDivOp>(loc, minusOne, m);
Value plusOne = rewriter.create<spirv::IAddOp>(loc, quotient, one);
Value minusOne = spirv::ISubOp::create(rewriter, loc, n, one);
Value quotient = spirv::UDivOp::create(rewriter, loc, minusOne, m);
Value plusOne = spirv::IAddOp::create(rewriter, loc, quotient, one);

// Pick the result
Value cmp = rewriter.create<spirv::IEqualOp>(loc, n, zero);
Value cmp = spirv::IEqualOp::create(rewriter, loc, n, zero);
rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, cmp, zero, plusOne);
return success();
}
Expand All @@ -197,32 +197,33 @@ struct ConvertIndexFloorDivSPattern final : OpConversionPattern<FloorDivSOp> {
Value m = adaptor.getRhs();

// Define the constants
Value zero = rewriter.create<spirv::ConstantOp>(
loc, n_type, IntegerAttr::get(n_type, 0));
Value posOne = rewriter.create<spirv::ConstantOp>(
loc, n_type, IntegerAttr::get(n_type, 1));
Value negOne = rewriter.create<spirv::ConstantOp>(
loc, n_type, IntegerAttr::get(n_type, -1));
Value zero = spirv::ConstantOp::create(rewriter, loc, n_type,
IntegerAttr::get(n_type, 0));
Value posOne = spirv::ConstantOp::create(rewriter, loc, n_type,
IntegerAttr::get(n_type, 1));
Value negOne = spirv::ConstantOp::create(rewriter, loc, n_type,
IntegerAttr::get(n_type, -1));

// Compute `x`.
Value mNeg = rewriter.create<spirv::SLessThanOp>(loc, m, zero);
Value x = rewriter.create<spirv::SelectOp>(loc, mNeg, posOne, negOne);
Value mNeg = spirv::SLessThanOp::create(rewriter, loc, m, zero);
Value x = spirv::SelectOp::create(rewriter, loc, mNeg, posOne, negOne);

// Compute the negative result
Value xMinusN = rewriter.create<spirv::ISubOp>(loc, x, n);
Value xMinusNDivM = rewriter.create<spirv::SDivOp>(loc, xMinusN, m);
Value negRes = rewriter.create<spirv::ISubOp>(loc, negOne, xMinusNDivM);
Value xMinusN = spirv::ISubOp::create(rewriter, loc, x, n);
Value xMinusNDivM = spirv::SDivOp::create(rewriter, loc, xMinusN, m);
Value negRes = spirv::ISubOp::create(rewriter, loc, negOne, xMinusNDivM);

// Compute the positive result.
Value posRes = rewriter.create<spirv::SDivOp>(loc, n, m);
Value posRes = spirv::SDivOp::create(rewriter, loc, n, m);

// Pick the negative result if `n` and `m` have different signs and `n` is
// non-zero, i.e. `(n < 0) != (m < 0) && n != 0`.
Value nNeg = rewriter.create<spirv::SLessThanOp>(loc, n, zero);
Value diffSign = rewriter.create<spirv::LogicalNotEqualOp>(loc, nNeg, mNeg);
Value nNonZero = rewriter.create<spirv::INotEqualOp>(loc, n, zero);
Value nNeg = spirv::SLessThanOp::create(rewriter, loc, n, zero);
Value diffSign =
spirv::LogicalNotEqualOp::create(rewriter, loc, nNeg, mNeg);
Value nNonZero = spirv::INotEqualOp::create(rewriter, loc, n, zero);

Value cmp = rewriter.create<spirv::LogicalAndOp>(loc, diffSign, nNonZero);
Value cmp = spirv::LogicalAndOp::create(rewriter, loc, diffSign, nNonZero);
rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, cmp, posRes, negRes);
return success();
}
Expand Down
Loading
Loading