Skip to content

Commit 431afd7

Browse files
XLS Teamcopybara-github
authored andcommitted
Automated rollback of commit 86f06e5.
PiperOrigin-RevId: 745561643
1 parent 86f06e5 commit 431afd7

File tree

6 files changed

+28
-80
lines changed

6 files changed

+28
-80
lines changed

xls/contrib/mlir/IR/xls_ops.cc

Lines changed: 2 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -348,29 +348,14 @@ LogicalResult VectorizedCallOp::verifySymbolUses(
348348
}
349349
if (innerType != expectedInnerType) {
350350
return emitError(
351-
"return type in callee mismatch with scalarized call result: ")
351+
"Expected return type in callee to match scalarized "
352+
"call result type, got: ")
352353
<< innerType << " vs expected " << expectedInnerType;
353354
}
354355
}
355356
return success();
356357
}
357358

358-
LogicalResult VectorizedCallOp::canonicalize(VectorizedCallOp op,
359-
PatternRewriter& rewriter) {
360-
bool tensor =
361-
llvm::any_of(op.getOperands(),
362-
[](Value v) { return isa<TensorType>(v.getType()); }) ||
363-
llvm::any_of(op.getResultTypes(),
364-
[](Type t) { return isa<TensorType>(t); });
365-
if (!tensor) {
366-
// If no tensor operands or results, replace with plain call.
367-
rewriter.replaceOpWithNewOp<mlir::func::CallOp>(
368-
op, op.getCalleeAttr(), op->getResultTypes(), op.getOperands());
369-
return success();
370-
}
371-
return failure();
372-
}
373-
374359
OpFoldResult AddOp::fold(FoldAdaptor adaptor) {
375360
auto lhs = dyn_cast_or_null<IntegerAttr>(adaptor.getLhs());
376361
auto rhs = dyn_cast_or_null<IntegerAttr>(adaptor.getRhs());

xls/contrib/mlir/IR/xls_ops.td

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1556,8 +1556,7 @@ def Xls_CallDslxOp : Xls_Op<"call_dslx", [DeclareOpInterfaceMethods<MemoryEffect
15561556
StrAttr:$filename,
15571557
StrAttr:$function,
15581558
Variadic<AnyType>:$operands,
1559-
UnitAttr:$is_pure,
1560-
UnitAttr:$is_vector_call
1559+
UnitAttr:$is_pure
15611560
// TODO(jpienaar): Expand to include search paths, define symbols etc.
15621561
);
15631562
let results = (outs AnyType);
@@ -2371,8 +2370,6 @@ def VectorizedCallOp : Xls_Op<"vectorized_call",
23712370
let assemblyFormat = [{
23722371
$callee `(` operands `)` attr-dict `:` functional-type(operands, results)
23732372
}];
2374-
2375-
let hasCanonicalizeMethod = 1;
23762373
}
23772374

23782375
//===----------------------------------------------------------------------===//

xls/contrib/mlir/testdata/arith_to_xls.mlir

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
// RUN: xls_opt -arith-to-xls -canonicalize %s 2>&1 | FileCheck %s
22

3+
34
// CHECK-LABEL: @constants
45
// CHECK-DAG: arith.constant 1
56
// CHECK-DAG: arith.constant dense

xls/contrib/mlir/testdata/integration/addf.mlir

Lines changed: 0 additions & 10 deletions
This file was deleted.

xls/contrib/mlir/transforms/arith_to_xls_patterns.td

Lines changed: 19 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -52,24 +52,21 @@ class FCmpPat<string s, dag target> :
5252
Pat<(Arith_CmpFOp:$op ConstantEnumCase<Arith_CmpFPredicateAttr, s>, $lhs, $rhs, $fmf),
5353
target>;
5454

55-
defvar PureCall = ConstUnitAttr;
56-
defvar VectorCall = ConstUnitAttr;
57-
5855
def : FCmpPat<"oeq", (Xls_CallDslxOp (FloatLib $lhs), CS<"eq_2">,
59-
(variadic $lhs, $rhs), PureCall, VectorCall)>;
56+
(variadic $lhs, $rhs), ConstUnitAttr)>;
6057
def : FCmpPat<"ogt", (Xls_CallDslxOp (FloatLib $lhs), CS<"gt_2">,
61-
(variadic $lhs, $rhs), PureCall, VectorCall)>;
58+
(variadic $lhs, $rhs), ConstUnitAttr)>;
6259
def : FCmpPat<"oge", (Xls_CallDslxOp (FloatLib $lhs), CS<"gte_2">,
63-
(variadic $lhs, $rhs), PureCall, VectorCall)>;
60+
(variadic $lhs, $rhs), ConstUnitAttr)>;
6461
def : FCmpPat<"olt", (Xls_CallDslxOp (FloatLib $lhs), CS<"lt_2">,
65-
(variadic $lhs, $rhs), PureCall, VectorCall)>;
62+
(variadic $lhs, $rhs), ConstUnitAttr)>;
6663
def : FCmpPat<"ole", (Xls_CallDslxOp (FloatLib $lhs), CS<"lte_2">,
67-
(variadic $lhs, $rhs), PureCall, VectorCall)>;
64+
(variadic $lhs, $rhs), ConstUnitAttr)>;
6865

6966
// Define unordered(x, y) = is_nan(x + y)
7067
def : FCmpPat<"uno", (Xls_CallDslxOp (FloatLib $lhs), CS<"is_nan">,
7168
(variadic (Arith_AddFOp $lhs, $rhs, $fmf)),
72-
PureCall, VectorCall)>;
69+
ConstUnitAttr)>;
7370

7471
class OppositeFCmpPat<string s, string r> :
7572
Pat<(Arith_CmpFOp:$op ConstantEnumCase<Arith_CmpFPredicateAttr, s>, $lhs,
@@ -135,70 +132,69 @@ def : Pat<(Arith_ConstantOp NonIndexAttr:$val),
135132
defvar ExtLib = CS<"xls/contrib/mlir/stdlib/fp_ext_trunc.x">;
136133

137134
def : Pat<(Arith_NegFOp:$op $a, /*FastMathFlags=*/$_),
138-
(Xls_CallDslxOp (FloatLib $a), CS<"negate">, (variadic $a), PureCall, VectorCall)>;
135+
(Xls_CallDslxOp (FloatLib $a), CS<"negate">, (variadic $a), ConstUnitAttr)>;
139136

140137
// Emits a binary float library call.
141138
class FloatLibcall<Op Op, string Name> :
142139
Pat<(Op:$op $a, $b, /*FastMathFlags=*/$_),
143-
(Xls_CallDslxOp (FloatLib $a), CS<Name>, (variadic $a, $b),
144-
PureCall, VectorCall)>;
140+
(Xls_CallDslxOp (FloatLib $a), CS<Name>, (variadic $a, $b), ConstUnitAttr)>;
145141

146142
def : FloatLibcall<Arith_AddFOp, "add">;
147143
def : FloatLibcall<Arith_MulFOp, "mul">;
148144
def : FloatLibcall<Arith_SubFOp, "sub">;
149145

150146
def : Pat<(Arith_ExtFOp:$op $a, /*FastMathFlags=*/$_),
151-
(Xls_CallDslxOp ExtLib, CS<"ext">, (variadic $a), PureCall, VectorCall)>;
147+
(Xls_CallDslxOp ExtLib, CS<"ext">, (variadic $a), ConstUnitAttr)>;
152148

153149
def : Pat<(Arith_TruncFOp:$op $a, /*RoundingMode=*/$_, /*FastMathFlags=*/$_),
154-
(Xls_CallDslxOp ExtLib, CS<"trunc">, (variadic $a), PureCall, VectorCall),
150+
(Xls_CallDslxOp ExtLib, CS<"trunc">, (variadic $a), ConstUnitAttr),
155151
[(ScalarOrTensorOf<F32> $a)]>;
156152

157153
multiclass FloatToIntegralPat<Op fptoi, Op itofp, Op exti, string u> {
158154
def : Pat<(itofp:$op I32:$a),
159155
(Xls_CallDslxOp (FloatLib $op), CS<"from_"#u#"int32">, (variadic $a),
160-
PureCall, VectorCall),
156+
ConstUnitAttr),
161157
[(ScalarOrTensorOf<F32> $op)]>;
162158

163159
def : Pat<(itofp:$op I32:$a),
164160
(Xls_CallDslxOp (FloatLib $op), CS<"from_float32">,
165161
(variadic (itofp $a, (returnType "$_builder.getF32Type()"))),
166-
PureCall, VectorCall),
162+
ConstUnitAttr),
167163
[(ScalarOrTensorOf<BF16> $op)]>;
168164

169165
def : Pat<(itofp:$op I16:$a),
170166
(Xls_CallDslxOp (FloatLib $op), CS<"from_"#u#"int32">,
171167
(variadic (exti $a, (returnType "$_builder.getI32Type()"))),
172-
PureCall, VectorCall),
168+
ConstUnitAttr),
173169
[(ScalarOrTensorOf<F32> $op)]>;
174170

175171
def : Pat<(itofp:$op I16:$a),
176172
(Xls_CallDslxOp (FloatLib $op), CS<"from_float32">,
177173
(variadic (itofp $a, (returnType "$_builder.getF32Type()"))),
178-
PureCall, VectorCall),
174+
ConstUnitAttr),
179175
[(ScalarOrTensorOf<BF16> $op)]>;
180176

181177
def : Pat<(itofp:$op I8:$a),
182178
(Xls_CallDslxOp (FloatLib $op), CS<"from_"#u#"int8">, (variadic $a),
183-
PureCall, VectorCall),
179+
ConstUnitAttr),
184180
[(ScalarOrTensorOf<BF16> $op)]>;
185181

186182
def : Pat<(fptoi:$op F32:$a),
187183
(Xls_CallDslxOp (FloatLib $a), CS<"to_"#u#"int32">, (variadic $a),
188-
PureCall, VectorCall),
184+
ConstUnitAttr),
189185
[(ScalarOrTensorOf<I32> $op)]>;
190186

191187
def : Pat<(fptoi:$op BF16:$a),
192188
(Xls_CallDslxOp (FloatLib $a), CS<"to_"#u#"int16">, (variadic $a),
193-
PureCall, VectorCall),
189+
ConstUnitAttr),
194190
[(ScalarOrTensorOf<I16> $op)]>;
195191

196192
// TODO(jmolloy): to_int8 doesn't exist, so truncating the result of to_int16
197193
// seems like a reasonable approximation but I don't know if it's bit accurate.
198194
def : Pat<(fptoi:$op BF16:$a),
199195
(Arith_TruncIOp
200196
(Xls_CallDslxOp (FloatLib $a), CS<"to_"#u#"int16">, (variadic $a),
201-
PureCall, VectorCall,
197+
ConstUnitAttr,
202198
(returnType "$_builder.getI16Type()"))),
203199
[(ScalarOrTensorOf<I8> $op)]>;
204200
}
@@ -283,7 +279,7 @@ class MinMaxPatBase<dag Matcher, dag Predicate> : Pat<
283279

284280
class FPMinMaxPat<Op Op, string Name> : MinMaxPatBase<
285281
(Op:$op $a, $b, /*FastMathFlags=*/$_),
286-
(Xls_CallDslxOp (FloatLib $a), CS<Name>, (variadic $a, $b), PureCall, VectorCall,
282+
(Xls_CallDslxOp (FloatLib $a), CS<Name>, (variadic $a, $b), ConstUnitAttr,
287283
(returnType "boolLike(op)"))>;
288284

289285
def : FPMinMaxPat<Arith_MaximumFOp, "gt_2">;

xls/contrib/mlir/transforms/normalize_xls_calls.cc

Lines changed: 5 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,6 @@
1717
#include <vector>
1818

1919
#include "llvm/include/llvm/ADT/DenseMap.h"
20-
#include "llvm/include/llvm/ADT/STLExtras.h"
21-
#include "llvm/include/llvm/ADT/SmallVector.h"
2220
#include "llvm/include/llvm/ADT/StringMap.h"
2321
#include "llvm/include/llvm/ADT/StringRef.h"
2422
#include "llvm/include/llvm/Support/Casting.h"
@@ -27,8 +25,6 @@
2725
#include "mlir/include/mlir/IR/Builders.h"
2826
#include "mlir/include/mlir/IR/BuiltinOps.h"
2927
#include "mlir/include/mlir/IR/SymbolTable.h"
30-
#include "mlir/include/mlir/IR/TypeRange.h"
31-
#include "mlir/include/mlir/IR/TypeUtilities.h"
3228
#include "mlir/include/mlir/IR/Visitors.h"
3329
#include "mlir/include/mlir/Pass/Pass.h" // IWYU pragma: keep
3430
#include "xls/contrib/mlir/IR/xls_ops.h"
@@ -78,19 +74,8 @@ void NormalizeXlsCallsPass::runOnOperation() {
7874
it->second.push_back(pkgImport);
7975
}
8076

81-
FunctionType newType;
82-
if (call.getIsVectorCall()) {
83-
auto scalarize = [](TypeRange r) {
84-
return llvm::to_vector(llvm::map_range(
85-
r, [](Type t) { return getElementTypeOrSelf(t); }));
86-
};
87-
newType = builder.getFunctionType(scalarize(call->getOperandTypes()),
88-
scalarize(call->getResultTypes()));
89-
} else {
90-
newType = builder.getFunctionType(call->getOperandTypes(),
91-
call->getResultTypes());
92-
}
93-
77+
auto newType = builder.getFunctionType(call->getOperandTypes(),
78+
call->getResultTypes());
9479
auto func = builder.create<mlir::func::FuncOp>(
9580
op->getLoc(),
9681
llvm::formatv("{0}_{1}", path.stem(), call.getFunction()).str(),
@@ -106,15 +91,9 @@ void NormalizeXlsCallsPass::runOnOperation() {
10691
}
10792

10893
OpBuilder b(op);
109-
Operation* fnCall;
110-
if (call.getIsVectorCall()) {
111-
fnCall = b.create<mlir::xls::VectorizedCallOp>(
112-
op->getLoc(), fIt->second.front(), call.getOperands());
113-
} else {
114-
fnCall = b.create<mlir::func::CallOp>(op->getLoc(), fIt->second.front(),
115-
call.getOperands());
116-
}
117-
op->replaceAllUsesWith(fnCall->getResults());
94+
auto fnCall = b.create<mlir::func::CallOp>(
95+
op->getLoc(), fIt->second.front(), call.getOperands());
96+
op->replaceAllUsesWith(fnCall.getResults());
11897
op->erase();
11998
}
12099
return WalkResult::advance();

0 commit comments

Comments
 (0)