
Commit 928b5b0

[mlir][sparse] add conversion rules for storage_get/set/callOp

Author: Peiming Liu
Reviewed By: aartbik
Differential Revision: https://reviews.llvm.org/D133175

1 parent 46b293c
3 files changed: +161 -16 lines

mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp

Lines changed: 3 additions & 1 deletion
@@ -217,10 +217,12 @@ struct SparseTensorStorageExpansionPass
     target.addDynamicallyLegalOp<func::ReturnOp>([&](func::ReturnOp op) {
       return converter.isLegal(op.getOperandTypes());
     });
+    // We generate UnrealizedConversionCastOp to intermix tuples and
+    // lists of types.
+    target.addLegalOp<UnrealizedConversionCastOp>();
     // Populate with rules and apply rewriting rules.
     populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(patterns,
                                                                    converter);
-    populateCallOpTypeConversionPattern(patterns, converter);
    scf::populateSCFStructuralTypeConversionsAndLegality(converter, patterns,
                                                         target);
    populateSparseTensorStorageExpansionPatterns(converter, patterns);
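
The newly legal cast is what holds the IR together mid-conversion: expanding a tuple into its element types is a 1:N rewrite, so the converter briefly materializes a cast from the expanded values back to the tuple. A minimal sketch of that transient IR, with hypothetical value names (not taken from the commit):

    // Illustrative only: %a, %b, %c stand for the expanded storage fields.
    %t = builtin.unrealized_conversion_cast %a, %b, %c
         : memref<?xf64>, memref<?xf64>, f64
         to tuple<memref<?xf64>, memref<?xf64>, f64>

Once every consumer of %t has been rewritten against the expanded values, the cast becomes dead, which is presumably why the updated test pipeline below adds -cse to tidy the output.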

mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorStorageExpansion.cpp

Lines changed: 119 additions & 14 deletions
@@ -41,10 +41,69 @@ convertSparseTensorStorageTuple(Type t, SmallVectorImpl<Type> &result) {
   return llvm::None;
 }
 
+/// Flattens a list of operands that may contain tuples.
+static void flattenOperands(ValueRange operands,
+                            SmallVectorImpl<Value> &flattened) {
+  // In case of:
+  //   tuple<a, b>, c, tuple<d, e>
+  // ==>
+  //   a, b, c, d, e
+  for (auto operand : operands) {
+    if (auto cast =
+            dyn_cast<UnrealizedConversionCastOp>(operand.getDefiningOp());
+        cast && cast->getResultTypes()[0].isa<TupleType>())
+      // An unrealized_conversion_cast will be inserted by the type converter
+      // to bridge the 1:N conversion between a tuple and its element types.
+      // In this case, take the operands of the cast and replace the tuple
+      // output with the flattened list of values.
+      flattened.append(cast.getOperands().begin(), cast.getOperands().end());
+    else
+      flattened.push_back(operand);
+  }
+}
 //===----------------------------------------------------------------------===//
 // Conversion rules.
 //===----------------------------------------------------------------------===//
 
+/// Sparse tensor storage conversion rule for sparse_tensor::storage_get.
+class SparseStorageGetConverter : public OpConversionPattern<StorageGetOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(StorageGetOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto castOp =
+        cast<UnrealizedConversionCastOp>(adaptor.getStorage().getDefiningOp());
+    uint64_t idx = op.getIdx().getZExtValue();
+    assert(idx < castOp.getOperands().size());
+
+    rewriter.replaceOp(op, castOp.getOperand(idx));
+    return success();
+  }
+};
+
+/// Sparse tensor storage conversion rule for sparse_tensor::storage_set.
+class SparseStorageSetConverter : public OpConversionPattern<StorageSetOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(StorageSetOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto castOp =
+        cast<UnrealizedConversionCastOp>(adaptor.getStorage().getDefiningOp());
+    uint64_t idx = op.getIdx().getZExtValue();
+
+    SmallVector<Value, 8> values(castOp.getOperands());
+    assert(idx < values.size());
+
+    // Update the corresponding element.
+    values[idx] = adaptor.getValue();
+    rewriter.replaceOpWithNewOp<UnrealizedConversionCastOp>(
+        op, TypeRange{op.getType()}, values);
+    return success();
+  }
+};
+
 /// Sparse tensor storage conversion rule for returns.
 class SparseStorageReturnConverter
     : public OpConversionPattern<func::ReturnOp> {
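
Both new rules exploit the same invariant: the converted storage value always arrives wrapped in such a cast (the code asserts this via `cast<UnrealizedConversionCastOp>`), so storage_get simply forwards the cast's idx-th operand, and storage_set re-issues the cast with one operand swapped. A hedged sketch on IR, with hypothetical value names:

    %t = builtin.unrealized_conversion_cast %a, %b, %c
         : memref<?xf64>, memref<?xf64>, f64
         to tuple<memref<?xf64>, memref<?xf64>, f64>
    // storage_get %t[0] folds to plain %a.
    // storage_set %t[0], %new re-materializes the tuple as:
    %t2 = builtin.unrealized_conversion_cast %new, %b, %c
          : memref<?xf64>, memref<?xf64>, f64
          to tuple<memref<?xf64>, memref<?xf64>, f64>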
@@ -54,24 +113,69 @@ class SparseStorageReturnConverter
   matchAndRewrite(func::ReturnOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     SmallVector<Value, 8> flattened;
-    for (auto operand : adaptor.getOperands()) {
-      if (auto cast =
-              dyn_cast<UnrealizedConversionCastOp>(operand.getDefiningOp());
-          cast && cast->getResultTypes()[0].isa<TupleType>())
-        // An unrealized_conversion_cast will be inserted by the type converter
-        // to bridge the 1:N conversion between a tuple and its element types.
-        // In this case, take the operands of the cast and replace the tuple
-        // output with the flattened list of values.
-        flattened.append(cast.getOperands().begin(), cast.getOperands().end());
-      else
-        flattened.push_back(operand);
-    }
+    flattenOperands(adaptor.getOperands(), flattened);
     // Create a return with the flattened values extracted from the tuple.
     rewriter.replaceOpWithNewOp<func::ReturnOp>(op, flattened);
     return success();
   }
 };
 
+/// Sparse tensor storage conversion rule for calls.
+class SparseStorageCallConverter : public OpConversionPattern<func::CallOp> {
+public:
+  // The default CallOp converter cannot handle 1:N type conversion properly.
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(func::CallOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    // In case of:
+    //   tuple(a, b), f, tuple(c, d) = call @foo(...)
+    // ==>
+    //   a, b, f, c, d = call @foo(...)
+    //   cast(a, b)->tuple, f, cast(c, d)->tuple
+    SmallVector<Type, 8> finalRetTy;
+    if (failed(typeConverter->convertTypes(op.getResultTypes(), finalRetTy)))
+      return failure();
+
+    // (1) Generate a new call with flattened return values.
+    SmallVector<Value, 8> flattened;
+    flattenOperands(adaptor.getOperands(), flattened);
+    auto newCall = rewriter.create<func::CallOp>(loc, op.getCallee(),
+                                                 finalRetTy, flattened);
+
+    // (2) Create cast operations for tuple returns.
+    SmallVector<Value, 4> castedRet;
+    // Tracks the offset of the current return value (of the original call)
+    // relative to the new call (after tuple flattening).
+    unsigned retOffset = 0;
+    for (auto ret : op.getResults()) {
+      assert(retOffset < newCall.getNumResults());
+      auto tupleRet = ret.getType().dyn_cast<TupleType>();
+      if (tupleRet) {
+        auto tupleSize = tupleRet.size();
+        // NOTE: The range is computed under the assumption of non-recursive
+        // tuple types.
+        ValueRange tupleElem(iterator_range<ResultRange::iterator>(
+            newCall.result_begin() + retOffset,
+            newCall.result_begin() + retOffset + tupleSize));
+        auto castOp = rewriter.create<UnrealizedConversionCastOp>(
+            loc, TypeRange({tupleRet}), tupleElem);
+        castedRet.push_back(castOp.getResult(0));
+        retOffset += tupleSize;
+      } else {
+        // If this is not a tuple, simply add it to the returned values.
+        castedRet.push_back(ret);
+        retOffset++;
+      }
+    }
+
+    assert(castedRet.size() == op.getNumResults());
+    rewriter.replaceOp(op, castedRet);
+    return success();
+  }
+};
+
 } // namespace
 
 //===----------------------------------------------------------------------===//
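
On IR, the call rule performs the 1:N re-packing sketched below; the callee @foo, its signature, and all value names are hypothetical, so treat this as an illustration rather than verbatim pass output:

    // Before: the call returns a tuple next to an ordinary scalar.
    %t, %f = func.call @foo(%x)
        : (f64) -> (tuple<memref<?xf64>, f64>, f64)

    // After: the tuple result is flattened into the call, and a cast
    // rebuilds the tuple value for any uses not yet rewritten.
    %a, %b, %f2 = func.call @foo(%x)
        : (f64) -> (memref<?xf64>, f64, f64)
    %t2 = builtin.unrealized_conversion_cast %a, %b
        : memref<?xf64>, f64 to tuple<memref<?xf64>, f64>

retOffset in the code tracks exactly this re-packing: it advances by tupleSize for tuple results and by one otherwise.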
@@ -91,6 +195,7 @@ mlir::SparseTensorStorageTupleExpander::SparseTensorStorageTupleExpander() {
 /// to expand compounded sparse tensor tuples.
 void mlir::populateSparseTensorStorageExpansionPatterns(
     TypeConverter &typeConverter, RewritePatternSet &patterns) {
-  patterns.add<SparseStorageReturnConverter>(typeConverter,
-                                             patterns.getContext());
+  patterns.add<SparseStorageGetConverter, SparseStorageSetConverter,
+               SparseStorageReturnConverter, SparseStorageCallConverter>(
+      typeConverter, patterns.getContext());
 }

mlir/test/Dialect/SparseTensor/sparse_tensor_storage.mlir

Lines changed: 39 additions & 1 deletion
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -sparse-tensor-storage-expansion | FileCheck %s
+// RUN: mlir-opt %s -sparse-tensor-storage-expansion -cse | FileCheck %s
 
 // CHECK-LABEL: func @sparse_storage_expand(
 // CHECK-SAME:  %[[TMP_arg0:.*0]]: memref<?xf64>,
@@ -9,3 +9,41 @@ func.func @sparse_storage_expand(%arg0: tuple<memref<?xf64>, memref<?xf64>, f64>
                                  -> tuple<memref<?xf64>, memref<?xf64>, f64> {
   return %arg0 : tuple<memref<?xf64>, memref<?xf64>, f64>
 }
+
+// CHECK-LABEL: func @call_sparse_storage_expand(
+// CHECK-SAME:  %[[TMP_arg0:.*0]]: memref<?xf64>,
+// CHECK-SAME:  %[[TMP_arg1:.*1]]: memref<?xf64>,
+// CHECK-SAME:  %[[TMP_arg2:.*]]: f64)
+// CHECK:       %[[TMP_0:.*]]:3 = call @sparse_storage_expand(%[[TMP_arg0]], %[[TMP_arg1]], %[[TMP_arg2]])
+// CHECK:       return %[[TMP_0]]#0, %[[TMP_0]]#1, %[[TMP_0]]#2 : memref<?xf64>, memref<?xf64>, f64
+func.func @call_sparse_storage_expand(%arg0: tuple<memref<?xf64>, memref<?xf64>, f64>)
+                                          -> tuple<memref<?xf64>, memref<?xf64>, f64> {
+  %1 = call @sparse_storage_expand(%arg0) : (tuple<memref<?xf64>, memref<?xf64>, f64>) ->
+                                             tuple<memref<?xf64>, memref<?xf64>, f64>
+  return %1 : tuple<memref<?xf64>, memref<?xf64>, f64>
+}
+
+// CHECK-LABEL: func @sparse_storage_get(
+// CHECK-SAME:  %[[TMP_arg0:.*0]]: memref<?xf64>,
+// CHECK-SAME:  %[[TMP_arg1:.*1]]: memref<?xf64>,
+// CHECK-SAME:  %[[TMP_arg2:.*]]: f64)
+// CHECK:       return %[[TMP_arg0]] : memref<?xf64>
+func.func @sparse_storage_get(%arg0: tuple<memref<?xf64>, memref<?xf64>, f64>) -> memref<?xf64> {
+  %0 = sparse_tensor.storage_get %arg0[0]
+       : tuple<memref<?xf64>, memref<?xf64>, f64> to memref<?xf64>
+  return %0 : memref<?xf64>
+}
+
+// CHECK-LABEL: func @sparse_storage_set(
+// CHECK-SAME:  %[[TMP_arg0:.*0]]: memref<?xf64>,
+// CHECK-SAME:  %[[TMP_arg1:.*1]]: memref<?xf64>,
+// CHECK-SAME:  %[[TMP_arg2:.*]]: f64,
+// CHECK-SAME:  %[[TMP_arg3:.*]]: memref<?xf64>)
+// CHECK:       return %[[TMP_arg3]], %[[TMP_arg1]], %[[TMP_arg2]] : memref<?xf64>, memref<?xf64>, f64
+func.func @sparse_storage_set(%arg0: tuple<memref<?xf64>, memref<?xf64>, f64>,
+                              %arg1: memref<?xf64>) -> tuple<memref<?xf64>, memref<?xf64>, f64> {
+  %0 = sparse_tensor.storage_set %arg0[0], %arg1
+       : tuple<memref<?xf64>, memref<?xf64>, f64>, memref<?xf64> to
+         tuple<memref<?xf64>, memref<?xf64>, f64>
+  return %0 : tuple<memref<?xf64>, memref<?xf64>, f64>
+}
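
For intuition, the @sparse_storage_set test should lower to roughly the following once -cse folds away the intermediate casts; this sketch is inferred from the CHECK lines above rather than copied from actual pass output:

    func.func @sparse_storage_set(%arg0: memref<?xf64>, %arg1: memref<?xf64>,
                                  %arg2: f64, %arg3: memref<?xf64>)
        -> (memref<?xf64>, memref<?xf64>, f64) {
      // The tuple argument is expanded into its three fields; setting
      // slot 0 simply swaps %arg3 into the returned values.
      return %arg3, %arg1, %arg2 : memref<?xf64>, memref<?xf64>, f64
    }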
