
Commit c6b7385

Revert "Update unstructured-to-memref pass to convert tts.scatter to linalg.generic (#227)"

This reverts commit ccba545.

1 parent f8dc27b

14 files changed: +290 -351 lines

lib/Conversion/UnstructuredToMemref/UnstructuredToMemrefPass.cpp

Lines changed: 92 additions & 106 deletions
@@ -52,7 +52,8 @@ class PtrToUnrankedMemrefConverter : public TypeConverter {
     });
     addTargetMaterialization([&](OpBuilder &builder,
                                  UnrankedMemRefType resultType,
-                                 ValueRange inputs, Location loc) -> Value {
+                                 ValueRange inputs,
+                                 Location loc) -> Value {
       return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs)
           .getResult(0);
     });
@@ -157,7 +158,7 @@ struct ScalarStoreConverter : public OpConversionPattern<tts::ScatterOp> {
   }
 };
 
-// Lowering an unstructured load op (gather) into a linalg.generic op.
+// Lowering an unstructured load op (gather) into a linalg.generic op
 struct GatherConverter : public OpConversionPattern<tts::GatherOp> {
   using OpConversionPattern<tts::GatherOp>::OpConversionPattern;
 
@@ -170,109 +171,118 @@ struct GatherConverter : public OpConversionPattern<tts::GatherOp> {
   LogicalResult
   matchAndRewrite(tts::GatherOp gatherOp, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
+
     auto loc = gatherOp->getLoc();
 
     auto ptr = adaptor.getPtr();
     auto offsetTensor = adaptor.getOffset();
     auto offsetType = dyn_cast<ShapedType>(offsetTensor.getType());
 
-    // This must be a scalar load, skip processing.
+    // This must be a scalar load, skip processing
     if (!offsetType) {
       return failure();
     }
 
-    auto resultType =
+    auto loadResultType =
         dyn_cast<RankedTensorType>(gatherOp.getResult().getType());
 
     // Treat the base pointer (memref) as 1D because the offsets are all
     // relative to a single base pointer (already collapsed).
-    auto baseMemref = rewriter
-                          .create<memref::CastOp>(
-                              loc,
-                              MemRefType::get({ShapedType::kDynamic},
-                                              resultType.getElementType()),
-                              ptr)
-                          .getResult();
+    auto baseMemref = rewriter.create<memref::CastOp>(
+        loc,
+        MemRefType::get({ShapedType::kDynamic},
+                        loadResultType.getElementType()),
+        ptr);
 
     auto baseTensor =
         rewriter
             .create<bufferization::ToTensorOp>(
                 loc,
                 RankedTensorType::get(
                     SmallVector<int64_t>(1, ShapedType::kDynamic),
-                    resultType.getElementType()),
+                    loadResultType.getElementType()),
                 baseMemref, true /* restrict */, false /* writable */)
             .getResult();
 
     // The linalg.generic op should have the following inputs:
-    // - the offset tensor.
-    // - an optional mask tensor if the gather op contains mask.
+    // - the offset tensor
+    // - an optional mask tensor if the load op contains mask
     SmallVector<Value> inputs{offsetTensor};
 
     if (gatherOp.getMask()) {
       inputs.push_back(gatherOp.getMask());
     }
 
-    auto emptyTensor = rewriter
-                           .create<tensor::EmptyOp>(loc, resultType.getShape(),
-                                                    resultType.getElementType())
-                           .getResult();
+    auto emptyTensor =
+        rewriter
+            .create<tensor::EmptyOp>(loc, loadResultType.getShape(),
+                                     loadResultType.getElementType())
+            .getResult();
 
-    // Affine maps for the inputs and one additional output.
+    // Affine maps for the inputs and output
+    // If no mask is used, 2 affine maps are generated; one for the input offset
+    // tensor, the other for the output tensor.
+    // If mask is used, the first 2 maps are for the offset and mask tensors
+    // while the last map is for the output tensor.
     SmallVector<AffineMap> affineMaps(
-        inputs.size() + 1,
-        rewriter.getMultiDimIdentityMap(resultType.getRank()));
-
-    // All iterator types are parallel.
-    SmallVector<utils::IteratorType> iteratorTypes(
-        resultType.getRank(), utils::IteratorType::parallel);
+        gatherOp.getMask() ? 3 : 2,
+        rewriter.getMultiDimIdentityMap(loadResultType.getRank()));
 
     auto genericOp = rewriter.create<linalg::GenericOp>(
-        loc, TypeRange{resultType}, inputs, ValueRange{emptyTensor}, affineMaps,
-        iteratorTypes, [&](OpBuilder &b, Location loc, ValueRange args) {
-          auto getValueAtIndex = [baseTensor](OpBuilder &b, Location loc,
-                                              Value index) -> Value {
+        loc, SmallVector<Type>({loadResultType}), inputs,
+        ValueRange{emptyTensor}, affineMaps,
+        SmallVector<utils::IteratorType>(loadResultType.getRank(),
+                                         utils::IteratorType::parallel),
+        [&](OpBuilder &b, Location loc, ValueRange args) {
+          auto getValueAtIndex = [baseTensor](Value indexValue, Location loc,
+                                              OpBuilder &b) -> Value {
            Value index0 =
-                b.create<arith::IndexCastOp>(loc, b.getIndexType(), index);
+                b.create<arith::IndexCastOp>(loc, b.getIndexType(), indexValue);
 
            return b.create<tensor::ExtractOp>(loc, baseTensor,
                                               ValueRange{index0});
          };
 
-          auto offset = args[0];
-
          if (!gatherOp.getMask()) {
            // If there is no mask, simply extract the current element from the
            // base tensor and use it as the yield value.
-            auto loadValue = getValueAtIndex(b, loc, offset);
-            b.create<linalg::YieldOp>(loc, loadValue);
+            auto loadValue = getValueAtIndex(args[0], loc, rewriter);
+            rewriter.create<linalg::YieldOp>(loc, loadValue);
          } else {
            // If the mask value is truthy, the current element is loaded from
            // the base tensor using its offset. Otherwise, if `other` is
            // present, yield `other`. If `other` is not present, a default
            // value of 0 is used.
            auto mask = args[1];
-            auto ifOp = b.create<scf::IfOp>(
+            auto ifOp = rewriter.create<scf::IfOp>(
                loc, mask,
                [&](OpBuilder &b, Location loc) {
-                  // Truthy case, load from the index.
-                  auto value = getValueAtIndex(b, loc, offset);
-                  b.create<scf::YieldOp>(loc, value);
+                  // Truthy case, load from the index
+                  auto loadValue = getValueAtIndex(args[0], loc, b);
+                  b.create<scf::YieldOp>(loc, loadValue);
                },
                [&](OpBuilder &b, Location loc) {
-                  // Falsy case, yield `other` or 0 as the default value.
+                  // Falsy case, yield `other` or 0 as the default value
                  if (gatherOp.getOther()) {
                    b.create<scf::YieldOp>(loc, gatherOp.getOther());
                  } else {
-                    auto elemType = resultType.getElementType();
-                    auto zeroAttr = b.getZeroAttr(elemType);
-                    assert(zeroAttr && "unexpected element type");
-                    Value extract = b.create<arith::ConstantOp>(loc, zeroAttr);
+                    auto elemType = baseTensor.getType().getElementType();
+                    Value extract;
+                    if (isa<IntegerType>(elemType)) {
+                      extract = rewriter.create<arith::ConstantOp>(
+                          loc, b.getIntegerAttr(elemType, 0));
+                    } else if (isa<FloatType>(elemType)) {
+                      extract = rewriter.create<arith::ConstantOp>(
+                          loc, b.getFloatAttr(elemType, 0));
+                    } else {
+                      elemType.dump();
+                      llvm_unreachable("unexpected type");
+                    }
                    b.create<scf::YieldOp>(loc, extract);
                  }
                });
 
-            b.create<linalg::YieldOp>(loc, ifOp->getResult(0));
+            rewriter.create<linalg::YieldOp>(loc, ifOp->getResult(0));
          }
        });

@@ -282,7 +292,7 @@ struct GatherConverter : public OpConversionPattern<tts::GatherOp> {
   }
 };
 
-// Lowering an unstructured store op (scatter) into a linalg.generic op.
+// Lowering an unstructured store op (scatter) into an affine loop nest
 struct ScatterConverter : public OpConversionPattern<tts::ScatterOp> {
   using OpConversionPattern<tts::ScatterOp>::OpConversionPattern;

@@ -299,81 +309,57 @@ struct ScatterConverter : public OpConversionPattern<tts::ScatterOp> {
 
     auto ptr = adaptor.getPtr();
     auto offsetTensor = adaptor.getOffset();
-    auto valueTensor = adaptor.getValue();
     auto offsetType = dyn_cast<ShapedType>(offsetTensor.getType());
 
-    // This must be a scalar store, skip processing.
+    // This must be a scalar store, skip processing
     if (!offsetType) {
      return failure();
    }
 
-    auto valueType = dyn_cast<RankedTensorType>(scatterOp.getValue().getType());
-
-    // Treat the base pointer (memref) as 1D because the offsets are all
-    // relative to a single base pointer (already collapsed).
-    auto baseMemref =
-        rewriter
-            .create<memref::CastOp>(loc,
-                                    MemRefType::get({ShapedType::kDynamic},
-                                                    valueType.getElementType()),
-                                    ptr)
-            .getResult();
-
-    // The linalg.generic op should have the following inputs:
-    // - the offset tensor.
-    // - the value tensor.
-    // - an optional mask tensor if the scatter op contains mask.
-    SmallVector<Value> inputs{offsetTensor, valueTensor};
+    auto resultType =
+        dyn_cast<RankedTensorType>(scatterOp.getValue().getType());
 
-    if (scatterOp.getMask()) {
-      inputs.push_back(scatterOp.getMask());
+    auto storeMemref = rewriter.create<memref::CastOp>(
+        loc,
+        MemRefType::get({ShapedType::kDynamic}, resultType.getElementType()),
+        ptr);
+
+    auto ip = rewriter.saveInsertionPoint();
+
+    SmallVector<Value> ivs;
+    for (auto dim : resultType.getShape()) {
+      auto ub =
+          rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(dim));
+      auto forOp = rewriter.create<affine::AffineForOp>(loc, 0, dim);
+      ivs.push_back(forOp.getInductionVar());
+      rewriter.setInsertionPointToStart(forOp.getBody());
    }
 
-    // Affine maps for the inputs.
-    SmallVector<AffineMap> affineMaps(
-        inputs.size(), rewriter.getMultiDimIdentityMap(valueType.getRank()));
-
-    // All iterator types are parallel.
-    SmallVector<utils::IteratorType> iteratorTypes(
-        valueType.getRank(), utils::IteratorType::parallel);
-
-    rewriter.setInsertionPoint(scatterOp);
-
-    auto genericOp = rewriter.create<linalg::GenericOp>(
-        loc, TypeRange{}, inputs, ValueRange{}, affineMaps, iteratorTypes,
-        [&](OpBuilder &b, Location loc, ValueRange args) {
-          auto storeValueAtIndex = [baseMemref](OpBuilder &b, Location loc,
-                                                Value index, Value value) {
-            Value index0 =
-                b.create<arith::IndexCastOp>(loc, b.getIndexType(), index);
-
-            b.create<memref::StoreOp>(loc, value, baseMemref,
-                                      ValueRange{index0});
-          };
+    if (scatterOp.getMask()) {
+      // Mask case, only store the value if the mask value at `ivs` is truthy
+      auto maskValue =
+          rewriter.create<tensor::ExtractOp>(loc, scatterOp.getMask(), ivs);
 
-          auto offset = args[0];
-          auto value = args[1];
+      auto ifOp = rewriter.create<scf::IfOp>(loc, maskValue,
+                                             false /* withElseRegion */);
 
-          if (!scatterOp.getMask()) {
-            // If there is no mask, simply insert the current value to the
-            // base memref using its offset.
-            storeValueAtIndex(b, loc, offset, value);
-          } else {
-            // If the mask value is truthy, insert the current value to the
-            // the base memref using its offset. Otherwise, noop.
-            auto mask = args[2];
-            auto ifOp =
-                b.create<scf::IfOp>(loc, mask, [&](OpBuilder &b, Location loc) {
-                  storeValueAtIndex(b, loc, offset, value);
-                  b.create<scf::YieldOp>(loc);
-                });
-          }
+      rewriter.setInsertionPointToStart(
+          &ifOp.getThenRegion().getBlocks().front());
+    }
 
-          b.create<linalg::YieldOp>(loc);
-        });
+    // Generate ops to store the value at each index. Note that with masking,
+    // these ops are created in the `if` block generated above.
+    auto offsetValue =
+        rewriter.create<tensor::ExtractOp>(loc, offsetTensor, ivs);
+    auto storeValue =
+        rewriter.create<tensor::ExtractOp>(loc, scatterOp.getValue(), ivs);
+    Value storeIndex = rewriter.create<arith::IndexCastOp>(
+        loc, rewriter.getIndexType(), offsetValue);
+    rewriter.create<memref::StoreOp>(loc, storeValue, storeMemref, storeIndex);
 
+    // Finalize
     rewriter.eraseOp(scatterOp);
-
+    rewriter.restoreInsertionPoint(ip);
     return success();
   }
 };
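
For orientation, the restored GatherConverter produces IR of roughly the following shape. This is a minimal sketch under assumed names and shapes (%arg0, %offsets, %mask, and the 1024-element f32 tensor are illustrative, not captured pass output): a masked gather becomes a linalg.generic whose body extracts from the 1D base tensor, with an scf.if selecting between the loaded element and the `other`/zero fallback.

    #map = affine_map<(d0) -> (d0)>
    // Base pointer viewed as a 1D tensor.
    %base = memref.cast %arg0 : memref<*xf32> to memref<?xf32>
    %base_t = bufferization.to_tensor %base restrict : memref<?xf32>
    %empty = tensor.empty() : tensor<1024xf32>
    %gathered = linalg.generic
        {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]}
        ins(%offsets, %mask : tensor<1024xi32>, tensor<1024xi1>)
        outs(%empty : tensor<1024xf32>) {
    ^bb0(%offset: i32, %m: i1, %out: f32):
      %v = scf.if %m -> (f32) {
        // Truthy case: load the element at the current offset.
        %idx = arith.index_cast %offset : i32 to index
        %elem = tensor.extract %base_t[%idx] : tensor<?xf32>
        scf.yield %elem : f32
      } else {
        // Falsy case: yield `other` when present, otherwise 0.
        %zero = arith.constant 0.0 : f32
        scf.yield %zero : f32
      }
      linalg.yield %v : f32
    } -> tensor<1024xf32>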
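
Likewise, a sketch of the loop nest the restored ScatterConverter emits for a masked scatter, under the same assumed names: one affine.for per dimension of the value tensor, and, because the insertion point moves into the scf.if then-block, the offset/value extracts and the store are only created (and executed) when the mask is truthy.

    %dst = memref.cast %arg1 : memref<*xf32> to memref<?xf32>
    affine.for %i = 0 to 1024 {
      %m = tensor.extract %mask[%i] : tensor<1024xi1>
      scf.if %m {
        %off = tensor.extract %offsets[%i] : tensor<1024xi32>
        %val = tensor.extract %values[%i] : tensor<1024xf32>
        %idx = arith.index_cast %off : i32 to index
        memref.store %val, %dst[%idx] : memref<?xf32>
      }
    }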

test/Conversion/StructuredToMemref/convert_1d_elemwise_arith_binary.mlir

Lines changed: 4 additions & 8 deletions
@@ -32,9 +32,7 @@ module {
 // CHECK-DAG: [[MAP_0_:#.+]] = affine_map<(d0) -> (d0)>
 // CHECK-LABEL: func.func @kernel
 // CHECK-SAME: ([[PARAM_0_:%.+]]: memref<*xf32>, [[PARAM_1_:%.+]]: memref<*xf32>, [[PARAM_2_:%.+]]: memref<*xf32>, [[PARAM_3_:%.+]]: i32, [[PARAM_4_:%.+]]: i32, [[PARAM_5_:%.+]]: i32, [[PARAM_6_:%.+]]: i32, [[PARAM_7_:%.+]]: i32, [[PARAM_8_:%.+]]: i32) {
-// CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0 : i32
-// CHECK-DAG: [[VAR_empty_offsets_:%.+]] = tensor.empty() : tensor<1024xi32>
-// CHECK-DAG: [[VAR_zero_offsets_:%.+]] = linalg.fill ins([[CST_0_]] : i32) outs([[VAR_empty_offsets_]] : tensor<1024xi32>) -> tensor<1024xi32>
+// CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0 : index
 // CHECK-DAG: [[VAR_reinterpret_cast_:%.+]] = memref.reinterpret_cast [[PARAM_0_]] to offset: [0], sizes: [1024], strides: [1] : memref<*xf32> to memref<1024xf32, strided<[1]>>
 // CHECK-DAG: [[VAR_reinterpret_cast_0_:%.+]] = memref.reinterpret_cast [[PARAM_1_]] to offset: [0], sizes: [1024], strides: [1] : memref<*xf32> to memref<1024xf32, strided<[1]>>
 // CHECK-DAG: [[RES_:%.+]] = memref.alloc() : memref<1024xf32>
@@ -75,11 +73,9 @@ module {
 // CHECK: linalg.yield [[VAR_9_5_]] : f32
 // CHECK: } -> tensor<1024xf32>
 // CHECK: [[VAR_cast_:%.+]] = memref.cast [[PARAM_2_]] : memref<*xf32> to memref<?xf32>
-// CHECK: linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins([[VAR_zero_offsets_]], [[VAR_8_]] : tensor<1024xi32>, tensor<1024xf32>) {
-// CHECK: ^bb0([[IN_18_:%.+]]: i32, [[IN_19_:%.+]]: f32):
-// CHECK: [[VAR_10_:%.+]] = arith.index_cast [[IN_18_]] : i32 to index
-// CHECK: memref.store [[IN_19_]], [[VAR_cast_]]{{.}}[[VAR_10_]]{{.}} : memref<?xf32>
-// CHECK: linalg.yield
+// CHECK: affine.for [[I_0_:%.+]] = 0 to 1024 {
+// CHECK: [[VAR_extracted_:%.+]] = tensor.extract [[VAR_8_]]{{.}}[[I_0_]]{{.}} : tensor<1024xf32>
+// CHECK: memref.store [[VAR_extracted_]], [[VAR_cast_]]{{.}}[[CST_0_]]{{.}} : memref<?xf32>
 // CHECK: }
 // CHECK: return
 // CHECK: }

test/Conversion/StructuredToMemref/convert_1d_elemwise_arith_ternary.mlir

Lines changed: 4 additions & 8 deletions
@@ -31,9 +31,7 @@ module {
 // CHECK-DAG: [[MAP_0_:#.+]] = affine_map<(d0) -> (d0)>
 // CHECK-LABEL: func.func @kernel
 // CHECK-SAME: ([[PARAM_0_:%.+]]: memref<*xi1>, [[PARAM_1_:%.+]]: memref<*xf32>, [[PARAM_2_:%.+]]: memref<*xf32>, [[PARAM_3_:%.+]]: memref<*xf32>, [[PARAM_4_:%.+]]: i32, [[PARAM_5_:%.+]]: i32, [[PARAM_6_:%.+]]: i32, [[PARAM_7_:%.+]]: i32, [[PARAM_8_:%.+]]: i32, [[PARAM_9_:%.+]]: i32) {
-// CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0 : i32
-// CHECK-DAG: [[VAR_empty_offsets_:%.+]] = tensor.empty() : tensor<1024xi32>
-// CHECK-DAG: [[VAR_zero_offsets_:%.+]] = linalg.fill ins([[CST_0_]] : i32) outs([[VAR_empty_offsets_]] : tensor<1024xi32>) -> tensor<1024xi32>
+// CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0 : index
 // CHECK-DAG: [[VAR_reinterpret_cast_:%.+]] = memref.reinterpret_cast [[PARAM_0_]] to offset: [0], sizes: [1024], strides: [1] : memref<*xi1> to memref<1024xi1, strided<[1]>>
 // CHECK-DAG: [[VAR_reinterpret_cast_0_:%.+]] = memref.reinterpret_cast [[PARAM_1_]] to offset: [0], sizes: [1024], strides: [1] : memref<*xf32> to memref<1024xf32, strided<[1]>>
 // CHECK-DAG: [[VAR_reinterpret_cast_1_:%.+]] = memref.reinterpret_cast [[PARAM_2_]] to offset: [0], sizes: [1024], strides: [1] : memref<*xf32> to memref<1024xf32, strided<[1]>>
@@ -52,11 +50,9 @@ module {
 // CHECK: linalg.yield [[VAR_4_]] : f32
 // CHECK: } -> tensor<1024xf32>
 // CHECK: [[VAR_cast_:%.+]] = memref.cast [[PARAM_3_]] : memref<*xf32> to memref<?xf32>
-// CHECK: linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins([[VAR_zero_offsets_]], [[VAR_3_]] : tensor<1024xi32>, tensor<1024xf32>) {
-// CHECK: ^bb0([[IN_3_:%.+]]: i32, [[IN_4_:%.+]]: f32):
-// CHECK: [[VAR_5_:%.+]] = arith.index_cast [[IN_3_]] : i32 to index
-// CHECK: memref.store [[IN_4_]], [[VAR_cast_]]{{.}}[[VAR_5_]]{{.}} : memref<?xf32>
-// CHECK: linalg.yield
+// CHECK: affine.for [[I_0_:%.+]] = 0 to 1024 {
+// CHECK: [[VAR_extracted_:%.+]] = tensor.extract [[VAR_3_]]{{.}}[[I_0_]]{{.}} : tensor<1024xf32>
+// CHECK: memref.store [[VAR_extracted_]], [[VAR_cast_]]{{.}}[[CST_0_]]{{.}} : memref<?xf32>
 // CHECK: }
 // CHECK: return
 // CHECK: }
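
For readability, IR of roughly this shape satisfies the updated CHECK lines in both tests; the [[...]] names above are FileCheck captures, and %computed/%out below are stand-ins for each kernel's result tensor and output pointer. The constant index-0 store mirrors the [[CST_0_]] capture these kernels expect.

    %c0 = arith.constant 0 : index
    %cast = memref.cast %out : memref<*xf32> to memref<?xf32>
    affine.for %i = 0 to 1024 {
      %e = tensor.extract %computed[%i] : tensor<1024xf32>
      memref.store %e, %cast[%c0] : memref<?xf32>
    }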
