Skip to content

Commit 57f7000

Browse files
authored
Add builtin function abs and numpy.abs (#178)
1 parent 1903cb5 commit 57f7000

File tree

7 files changed

+87
-3
lines changed

7 files changed

+87
-3
lines changed

mlir/lib/transforms/uplift_math.cpp

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,33 @@ struct UpliftMathCalls : public mlir::OpRewritePattern<mlir::func::CallOp> {
9090
}
9191
};
9292

93+
struct UpliftFabsCalls : public mlir::OpRewritePattern<mlir::func::CallOp> {
94+
using OpRewritePattern::OpRewritePattern;
95+
96+
mlir::LogicalResult
97+
matchAndRewrite(mlir::func::CallOp op,
98+
mlir::PatternRewriter &rewriter) const override {
99+
auto funcName = op.getCallee();
100+
if (funcName.empty())
101+
return mlir::failure();
102+
103+
if (funcName != "fabs" && funcName != "fabsf")
104+
return mlir::failure();
105+
106+
auto isNotValidType = [](mlir::Type t) {
107+
return !t.isa<mlir::FloatType>();
108+
};
109+
110+
if (op.getNumResults() != 1 || op.getNumOperands() != 1 ||
111+
llvm::any_of(op.getOperandTypes(), isNotValidType) ||
112+
llvm::any_of(op.getResultTypes(), isNotValidType))
113+
return mlir::failure();
114+
115+
rewriter.replaceOpWithNewOp<mlir::math::AbsOp>(op, op.operands()[0]);
116+
return mlir::success();
117+
}
118+
};
119+
93120
struct UpliftFma : public mlir::OpRewritePattern<mlir::arith::AddFOp> {
94121
using OpRewritePattern::OpRewritePattern;
95122

@@ -123,7 +150,7 @@ struct UpliftMathPass
123150
plier::DependentDialectsList<mlir::func::FuncDialect,
124151
mlir::arith::ArithmeticDialect,
125152
mlir::math::MathDialect>,
126-
UpliftMathCalls, UpliftFma> {};
153+
UpliftMathCalls, UpliftFabsCalls, UpliftFma> {};
127154
} // namespace
128155

129156
void plier::populateUpliftmathPatterns(mlir::MLIRContext &context,

numba_dpcomp/numba_dpcomp/mlir/builtin/funcs.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,3 +135,18 @@ def func(builder, *args):
135135

136136
_gen_math_funcs()
137137
del _gen_math_funcs
138+
139+
140+
@register_func("abs", abs)
141+
def abs_impl(builder, arg):
142+
t = arg.type
143+
if is_int(t, builder):
144+
c = arg < 0
145+
return builder.select(c, -arg, arg)
146+
if is_float(t, builder):
147+
fname = "fabs"
148+
if t == builder.float32:
149+
fname = fname + "f"
150+
151+
res = builder.cast(0, t)
152+
return builder.external_call(fname, arg, res, decorate=False)

numba_dpcomp/numba_dpcomp/mlir/linalg_builder.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,12 @@ def __eq__(self, o):
8383
def __ne__(self, o):
8484
return self._binop(self._context, self._ssa_val, o, "ne")
8585

86+
def __neg__(self):
87+
return self._unop(self._context, self._ssa_val, "-")
88+
89+
def __pos__(self):
90+
return self._unop(self._context, self._ssa_val, "+")
91+
8692
def __str__(self):
8793
return self._str(self._context, self._ssa_val)
8894

numba_dpcomp/numba_dpcomp/mlir/numpy/funcs.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -206,6 +206,7 @@ def bool_type(builder, t):
206206
(register_func("numpy.cos", numpy.cos), f64_type, lambda a, b: math.cos(a)),
207207
(register_func("numpy.exp", numpy.exp), f64_type, lambda a, b: math.exp(a)),
208208
(register_func("numpy.tanh", numpy.tanh), f64_type, lambda a, b: math.tanh(a)),
209+
(register_func("numpy.abs", numpy.abs), None, lambda a, b: abs(a)),
209210
(
210211
register_func("numpy.logical_not", numpy.logical_not),
211212
bool_type,

numba_dpcomp/numba_dpcomp/mlir/tests/test_basic.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ def py_func(a, b):
100100

101101

102102
@parametrize_function_variants(
103-
"py_func", ["lambda a: +a", "lambda a: -a", "lambda a: ~a",]
103+
"py_func", ["lambda a: +a", "lambda a: -a", "lambda a: ~a", "lambda a: abs(a)",]
104104
)
105105
@pytest.mark.parametrize("val", _test_values)
106106
def test_unary_ops(py_func, val, request):
@@ -184,7 +184,7 @@ def test_cast(py_func, val):
184184

185185

186186
@pytest.mark.parametrize("val", [1, 5, 5.5])
187-
@pytest.mark.parametrize("name", ["sqrt", "log", "exp", "sin", "cos", "erf", "tanh",])
187+
@pytest.mark.parametrize("name", ["sqrt", "log", "exp", "sin", "cos", "erf", "tanh"])
188188
def test_math_uplifting1(val, name):
189189
py_func = eval(f"lambda a: math.{name}(a)")
190190

numba_dpcomp/numba_dpcomp/mlir/tests/test_numpy.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,8 @@ def _vectorize_reference(func, arg1):
8787
"lambda a: np.cos(a)",
8888
"lambda a: np.exp(a)",
8989
"lambda a: np.tanh(a)",
90+
"lambda a: np.abs(a)",
91+
"lambda a: np.absolute(a)",
9092
"lambda a: a.size",
9193
"lambda a: a.T",
9294
"lambda a: a.T.T",

numba_dpcomp/numba_dpcomp/mlir_compiler/lib/py_linalg_resolver.cpp

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1742,6 +1742,38 @@ static py::object binopImpl(py::capsule context, py::capsule ssaVal,
17421742
plier::reportError("Unhandled binop type");
17431743
}
17441744

1745+
static py::object unopImpl(py::capsule context, py::capsule ssaVal,
1746+
py::str op) {
1747+
auto &ctx = getPyContext(context);
1748+
auto &builder = ctx.builder;
1749+
auto loc = ctx.loc;
1750+
auto val = unwrapMlir<mlir::Value>(ssaVal);
1751+
1752+
auto type = val.getType();
1753+
if (!type.isa<mlir::IntegerType, mlir::IndexType, mlir::FloatType>())
1754+
plier::reportError("Invalid unop arg type");
1755+
1756+
auto opName = static_cast<std::string>(op);
1757+
mlir::Value res;
1758+
if (opName == "+") {
1759+
res = val;
1760+
} else if (opName == "-") {
1761+
if (type.isa<mlir::FloatType>()) {
1762+
res = builder.create<mlir::arith::NegFOp>(loc, val);
1763+
} else {
1764+
auto signlessType = makeSignlessType(type);
1765+
auto zero = builder.getIntegerAttr(signlessType, 0);
1766+
auto zeroVal = builder.create<mlir::arith::ConstantOp>(loc, zero);
1767+
val = doSignCast(builder, loc, val);
1768+
res = builder.create<mlir::arith::SubIOp>(loc, zeroVal, val);
1769+
res = doSignCast(builder, loc, res, type);
1770+
}
1771+
} else {
1772+
plier::reportError("Unhandled unop type");
1773+
}
1774+
return ctx.context.createVar(context, res);
1775+
}
1776+
17451777
static py::object strImpl(py::capsule /*context*/, py::capsule ssaVal) {
17461778
return py::str("Var: \"" + toStr(unwrapMlir<mlir::Value>(ssaVal)) + "\"");
17471779
}
@@ -1753,6 +1785,7 @@ static void setupPyVar(pybind11::handle var) {
17531785
py::setattr(var, "_len", py::cpp_function(&lenImpl));
17541786
py::setattr(var, "_getitem", py::cpp_function(&getitemImpl));
17551787
py::setattr(var, "_binop", py::cpp_function(&binopImpl));
1788+
py::setattr(var, "_unop", py::cpp_function(&unopImpl));
17561789
py::setattr(var, "_str", py::cpp_function(&strImpl));
17571790
}
17581791

0 commit comments

Comments
 (0)