Commit d0fe5a9

Author: Xiaoran Weng (authored and committed)

Update unstructured-to-memref pass to convert tts.scatter to linalg.generic

1 parent: 6f718b7 · commit: d0fe5a9

14 files changed: +351 −290 lines changed
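
The net effect of this change: instead of materializing an affine.for loop nest that extracts and stores one element at a time, tts.scatter is now rewritten into a linalg.generic with no outputs whose body stores each value through the casted 1-D base memref. A minimal sketch of the generated IR for an unmasked 1-D scatter (SSA names and the 1024-element shapes are illustrative; the exact form is pinned down by the FileCheck updates below):

  #map = affine_map<(d0) -> (d0)>
  // %base: memref<*xf32>, %offsets: tensor<1024xi32>, %values: tensor<1024xf32>
  %cast = memref.cast %base : memref<*xf32> to memref<?xf32>
  linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]}
      ins(%offsets, %values : tensor<1024xi32>, tensor<1024xf32>) {
  ^bb0(%offset: i32, %value: f32):
    %idx = arith.index_cast %offset : i32 to index
    memref.store %value, %cast[%idx] : memref<?xf32>
    linalg.yield
  }

With identity indexing maps and all-parallel iterators, the scatter is expressed in the same linalg form as the gather lowering rather than as an opaque loop nest.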

lib/Conversion/UnstructuredToMemref/UnstructuredToMemrefPass.cpp

Lines changed: 106 additions & 92 deletions
@@ -52,8 +52,7 @@ class PtrToUnrankedMemrefConverter : public TypeConverter {
     });
     addTargetMaterialization([&](OpBuilder &builder,
                                  UnrankedMemRefType resultType,
-                                 ValueRange inputs,
-                                 Location loc) -> Value {
+                                 ValueRange inputs, Location loc) -> Value {
       return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs)
           .getResult(0);
     });
@@ -158,7 +157,7 @@ struct ScalarStoreConverter : public OpConversionPattern<tts::ScatterOp> {
   }
 };
 
-// Lowering an unstructured load op (gather) into a linalg.generic op
+// Lowering an unstructured load op (gather) into a linalg.generic op.
 struct GatherConverter : public OpConversionPattern<tts::GatherOp> {
   using OpConversionPattern<tts::GatherOp>::OpConversionPattern;
 
@@ -171,118 +170,109 @@ struct GatherConverter : public OpConversionPattern<tts::GatherOp> {
   LogicalResult
   matchAndRewrite(tts::GatherOp gatherOp, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-
     auto loc = gatherOp->getLoc();
 
     auto ptr = adaptor.getPtr();
     auto offsetTensor = adaptor.getOffset();
     auto offsetType = dyn_cast<ShapedType>(offsetTensor.getType());
 
-    // This must be a scalar load, skip processing
+    // This must be a scalar load, skip processing.
     if (!offsetType) {
       return failure();
     }
 
-    auto loadResultType =
+    auto resultType =
         dyn_cast<RankedTensorType>(gatherOp.getResult().getType());
 
     // Treat the base pointer (memref) as 1D because the offsets are all
     // relative to a single base pointer (already collapsed).
-    auto baseMemref = rewriter.create<memref::CastOp>(
-        loc,
-        MemRefType::get({ShapedType::kDynamic},
-                        loadResultType.getElementType()),
-        ptr);
+    auto baseMemref = rewriter
+                          .create<memref::CastOp>(
+                              loc,
+                              MemRefType::get({ShapedType::kDynamic},
+                                              resultType.getElementType()),
+                              ptr)
+                          .getResult();
 
     auto baseTensor =
         rewriter
             .create<bufferization::ToTensorOp>(
                 loc,
                 RankedTensorType::get(
                     SmallVector<int64_t>(1, ShapedType::kDynamic),
-                    loadResultType.getElementType()),
+                    resultType.getElementType()),
                 baseMemref, true /* restrict */, false /* writable */)
             .getResult();
 
     // The linalg.generic op should have the following inputs:
-    // - the offset tensor
-    // - an optional mask tensor if the load op contains mask
+    // - the offset tensor.
+    // - an optional mask tensor if the gather op contains mask.
     SmallVector<Value> inputs{offsetTensor};
 
     if (gatherOp.getMask()) {
       inputs.push_back(gatherOp.getMask());
     }
 
-    auto emptyTensor =
-        rewriter
-            .create<tensor::EmptyOp>(loc, loadResultType.getShape(),
-                                     loadResultType.getElementType())
-            .getResult();
+    auto emptyTensor = rewriter
+                           .create<tensor::EmptyOp>(loc, resultType.getShape(),
                                                     resultType.getElementType())
+                           .getResult();
 
-    // Affine maps for the inputs and output
-    // If no mask is used, 2 affine maps are generated; one for the input offset
-    // tensor, the other for the output tensor.
-    // If mask is used, the first 2 maps are for the offset and mask tensors
-    // while the last map is for the output tensor.
+    // Affine maps for the inputs and one additional output.
     SmallVector<AffineMap> affineMaps(
-        gatherOp.getMask() ? 3 : 2,
-        rewriter.getMultiDimIdentityMap(loadResultType.getRank()));
+        inputs.size() + 1,
+        rewriter.getMultiDimIdentityMap(resultType.getRank()));
+
+    // All iterator types are parallel.
+    SmallVector<utils::IteratorType> iteratorTypes(
+        resultType.getRank(), utils::IteratorType::parallel);
 
     auto genericOp = rewriter.create<linalg::GenericOp>(
-        loc, SmallVector<Type>({loadResultType}), inputs,
-        ValueRange{emptyTensor}, affineMaps,
-        SmallVector<utils::IteratorType>(loadResultType.getRank(),
-                                         utils::IteratorType::parallel),
-        [&](OpBuilder &b, Location loc, ValueRange args) {
-          auto getValueAtIndex = [baseTensor](Value indexValue, Location loc,
-                                              OpBuilder &b) -> Value {
+        loc, TypeRange{resultType}, inputs, ValueRange{emptyTensor}, affineMaps,
+        iteratorTypes, [&](OpBuilder &b, Location loc, ValueRange args) {
+          auto getValueAtIndex = [baseTensor](OpBuilder &b, Location loc,
+                                              Value index) -> Value {
            Value index0 =
-                b.create<arith::IndexCastOp>(loc, b.getIndexType(), indexValue);
+                b.create<arith::IndexCastOp>(loc, b.getIndexType(), index);
 
            return b.create<tensor::ExtractOp>(loc, baseTensor,
                                               ValueRange{index0});
          };
 
+          auto offset = args[0];
+
          if (!gatherOp.getMask()) {
            // If there is no mask, simply extract the current element from the
            // base tensor and use it as the yield value.
-            auto loadValue = getValueAtIndex(args[0], loc, rewriter);
-            rewriter.create<linalg::YieldOp>(loc, loadValue);
+            auto loadValue = getValueAtIndex(b, loc, offset);
+            b.create<linalg::YieldOp>(loc, loadValue);
          } else {
            // If the mask value is truthy, the current element is loaded from
            // the base tensor using its offset. Otherwise, if `other` is
            // present, yield `other`. If `other` is not present, a default
            // value of 0 is used.
            auto mask = args[1];
-            auto ifOp = rewriter.create<scf::IfOp>(
+            auto ifOp = b.create<scf::IfOp>(
                loc, mask,
                [&](OpBuilder &b, Location loc) {
-                  // Truthy case, load from the index
-                  auto loadValue = getValueAtIndex(args[0], loc, b);
-                  b.create<scf::YieldOp>(loc, loadValue);
+                  // Truthy case, load from the index.
+                  auto value = getValueAtIndex(b, loc, offset);
+                  b.create<scf::YieldOp>(loc, value);
                },
                [&](OpBuilder &b, Location loc) {
-                  // Falsy case, yield `other` or 0 as the default value
+                  // Falsy case, yield `other` or 0 as the default value.
                  if (gatherOp.getOther()) {
                    b.create<scf::YieldOp>(loc, gatherOp.getOther());
                  } else {
-                    auto elemType = baseTensor.getType().getElementType();
-                    Value extract;
-                    if (isa<IntegerType>(elemType)) {
-                      extract = rewriter.create<arith::ConstantOp>(
-                          loc, b.getIntegerAttr(elemType, 0));
-                    } else if (isa<FloatType>(elemType)) {
-                      extract = rewriter.create<arith::ConstantOp>(
-                          loc, b.getFloatAttr(elemType, 0));
-                    } else {
-                      elemType.dump();
-                      llvm_unreachable("unexpected type");
-                    }
+                    auto elemType = resultType.getElementType();
+                    auto zeroAttr = b.getZeroAttr(elemType);
+                    assert(zeroAttr && "unexpected element type");
+                    Value extract = b.create<arith::ConstantOp>(loc, zeroAttr);
                    b.create<scf::YieldOp>(loc, extract);
                  }
                });
 
-          rewriter.create<linalg::YieldOp>(loc, ifOp->getResult(0));
+          b.create<linalg::YieldOp>(loc, ifOp->getResult(0));
          }
        });
 
@@ -292,7 +282,7 @@ struct GatherConverter : public OpConversionPattern<tts::GatherOp> {
   }
 };
 
-// Lowering an unstructured store op (scatter) into an affine loop nest
+// Lowering an unstructured store op (scatter) into a linalg.generic op.
 struct ScatterConverter : public OpConversionPattern<tts::ScatterOp> {
   using OpConversionPattern<tts::ScatterOp>::OpConversionPattern;
 
@@ -309,57 +299,81 @@ struct ScatterConverter : public OpConversionPattern<tts::ScatterOp> {
 
     auto ptr = adaptor.getPtr();
     auto offsetTensor = adaptor.getOffset();
+    auto valueTensor = adaptor.getValue();
     auto offsetType = dyn_cast<ShapedType>(offsetTensor.getType());
 
-    // This must be a scalar store, skip processing
+    // This must be a scalar store, skip processing.
     if (!offsetType) {
       return failure();
     }
 
-    auto resultType =
-        dyn_cast<RankedTensorType>(scatterOp.getValue().getType());
+    auto valueType = dyn_cast<RankedTensorType>(scatterOp.getValue().getType());
 
-    auto storeMemref = rewriter.create<memref::CastOp>(
-        loc,
-        MemRefType::get({ShapedType::kDynamic}, resultType.getElementType()),
-        ptr);
-
-    auto ip = rewriter.saveInsertionPoint();
-
-    SmallVector<Value> ivs;
-    for (auto dim : resultType.getShape()) {
-      auto ub =
-          rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(dim));
-      auto forOp = rewriter.create<affine::AffineForOp>(loc, 0, dim);
-      ivs.push_back(forOp.getInductionVar());
-      rewriter.setInsertionPointToStart(forOp.getBody());
-    }
+    // Treat the base pointer (memref) as 1D because the offsets are all
+    // relative to a single base pointer (already collapsed).
+    auto baseMemref =
+        rewriter
+            .create<memref::CastOp>(loc,
+                                    MemRefType::get({ShapedType::kDynamic},
+                                                    valueType.getElementType()),
+                                    ptr)
+            .getResult();
+
+    // The linalg.generic op should have the following inputs:
+    // - the offset tensor.
+    // - the value tensor.
+    // - an optional mask tensor if the scatter op contains mask.
+    SmallVector<Value> inputs{offsetTensor, valueTensor};
 
     if (scatterOp.getMask()) {
-      // Mask case, only store the value if the mask value at `ivs` is truthy
-      auto maskValue =
-          rewriter.create<tensor::ExtractOp>(loc, scatterOp.getMask(), ivs);
+      inputs.push_back(scatterOp.getMask());
+    }
 
-      auto ifOp = rewriter.create<scf::IfOp>(loc, maskValue,
-                                             false /* withElseRegion */);
+    // Affine maps for the inputs.
+    SmallVector<AffineMap> affineMaps(
+        inputs.size(), rewriter.getMultiDimIdentityMap(valueType.getRank()));
 
-      rewriter.setInsertionPointToStart(
-          &ifOp.getThenRegion().getBlocks().front());
-    }
+    // All iterator types are parallel.
+    SmallVector<utils::IteratorType> iteratorTypes(
+        valueType.getRank(), utils::IteratorType::parallel);
 
-    // Generate ops to store the value at each index. Note that with masking,
-    // these ops are created in the `if` block generated above.
-    auto offsetValue =
-        rewriter.create<tensor::ExtractOp>(loc, offsetTensor, ivs);
-    auto storeValue =
-        rewriter.create<tensor::ExtractOp>(loc, scatterOp.getValue(), ivs);
-    Value storeIndex = rewriter.create<arith::IndexCastOp>(
-        loc, rewriter.getIndexType(), offsetValue);
-    rewriter.create<memref::StoreOp>(loc, storeValue, storeMemref, storeIndex);
+    rewriter.setInsertionPoint(scatterOp);
+
+    auto genericOp = rewriter.create<linalg::GenericOp>(
+        loc, TypeRange{}, inputs, ValueRange{}, affineMaps, iteratorTypes,
+        [&](OpBuilder &b, Location loc, ValueRange args) {
+          auto storeValueAtIndex = [baseMemref](OpBuilder &b, Location loc,
+                                                Value index, Value value) {
+            Value index0 =
+                b.create<arith::IndexCastOp>(loc, b.getIndexType(), index);
+
+            b.create<memref::StoreOp>(loc, value, baseMemref,
+                                      ValueRange{index0});
+          };
+
+          auto offset = args[0];
+          auto value = args[1];
+
+          if (!scatterOp.getMask()) {
+            // If there is no mask, simply insert the current value to the
+            // base memref using its offset.
+            storeValueAtIndex(b, loc, offset, value);
+          } else {
+            // If the mask value is truthy, insert the current value to the
+            // base memref using its offset. Otherwise, noop.
+            auto mask = args[2];
+            auto ifOp =
+                b.create<scf::IfOp>(loc, mask, [&](OpBuilder &b, Location loc) {
+                  storeValueAtIndex(b, loc, offset, value);
+                  b.create<scf::YieldOp>(loc);
+                });
+          }
+
+          b.create<linalg::YieldOp>(loc);
+        });
 
-    // Finalize
     rewriter.eraseOp(scatterOp);
-    rewriter.restoreInsertionPoint(ip);
+
     return success();
   }
 };
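
For comparison, the gather path (refactored above, but unchanged in behavior) produces a linalg.generic with one output tensor, and the masked case wraps the extract in an scf.if that falls back to `other` or the element type's zero. A rough sketch of the masked lowering with illustrative names, assuming no `other` operand is present:

  #map = affine_map<(d0) -> (d0)>
  // %base_tensor: tensor<?xf32>, the bufferization.to_tensor of the casted base memref
  %empty = tensor.empty() : tensor<1024xf32>
  %out = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]}
      ins(%offsets, %mask : tensor<1024xi32>, tensor<1024xi1>)
      outs(%empty : tensor<1024xf32>) {
  ^bb0(%offset: i32, %m: i1, %init: f32):
    %res = scf.if %m -> (f32) {
      %idx = arith.index_cast %offset : i32 to index
      %elem = tensor.extract %base_tensor[%idx] : tensor<?xf32>
      scf.yield %elem : f32
    } else {
      %zero = arith.constant 0.0 : f32
      scf.yield %zero : f32
    }
    linalg.yield %res : f32
  } -> tensor<1024xf32>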

test/Conversion/StructuredToMemref/convert_1d_elemwise_arith_binary.mlir

Lines changed: 8 additions & 4 deletions
@@ -32,7 +32,9 @@ module {
 // CHECK-DAG: [[MAP_0_:#.+]] = affine_map<(d0) -> (d0)>
 // CHECK-LABEL: func.func @kernel
 // CHECK-SAME: ([[PARAM_0_:%.+]]: memref<*xf32>, [[PARAM_1_:%.+]]: memref<*xf32>, [[PARAM_2_:%.+]]: memref<*xf32>, [[PARAM_3_:%.+]]: i32, [[PARAM_4_:%.+]]: i32, [[PARAM_5_:%.+]]: i32, [[PARAM_6_:%.+]]: i32, [[PARAM_7_:%.+]]: i32, [[PARAM_8_:%.+]]: i32) {
-// CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0 : index
+// CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0 : i32
+// CHECK-DAG: [[VAR_empty_offsets_:%.+]] = tensor.empty() : tensor<1024xi32>
+// CHECK-DAG: [[VAR_zero_offsets_:%.+]] = linalg.fill ins([[CST_0_]] : i32) outs([[VAR_empty_offsets_]] : tensor<1024xi32>) -> tensor<1024xi32>
 // CHECK-DAG: [[VAR_reinterpret_cast_:%.+]] = memref.reinterpret_cast [[PARAM_0_]] to offset: [0], sizes: [1024], strides: [1] : memref<*xf32> to memref<1024xf32, strided<[1]>>
 // CHECK-DAG: [[VAR_reinterpret_cast_0_:%.+]] = memref.reinterpret_cast [[PARAM_1_]] to offset: [0], sizes: [1024], strides: [1] : memref<*xf32> to memref<1024xf32, strided<[1]>>
 // CHECK-DAG: [[RES_:%.+]] = memref.alloc() : memref<1024xf32>
@@ -73,9 +75,11 @@ module {
 // CHECK: linalg.yield [[VAR_9_5_]] : f32
 // CHECK: } -> tensor<1024xf32>
 // CHECK: [[VAR_cast_:%.+]] = memref.cast [[PARAM_2_]] : memref<*xf32> to memref<?xf32>
-// CHECK: affine.for [[I_0_:%.+]] = 0 to 1024 {
-// CHECK: [[VAR_extracted_:%.+]] = tensor.extract [[VAR_8_]]{{.}}[[I_0_]]{{.}} : tensor<1024xf32>
-// CHECK: memref.store [[VAR_extracted_]], [[VAR_cast_]]{{.}}[[CST_0_]]{{.}} : memref<?xf32>
+// CHECK: linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins([[VAR_zero_offsets_]], [[VAR_8_]] : tensor<1024xi32>, tensor<1024xf32>) {
+// CHECK: ^bb0([[IN_18_:%.+]]: i32, [[IN_19_:%.+]]: f32):
+// CHECK: [[VAR_10_:%.+]] = arith.index_cast [[IN_18_]] : i32 to index
+// CHECK: memref.store [[IN_19_]], [[VAR_cast_]]{{.}}[[VAR_10_]]{{.}} : memref<?xf32>
+// CHECK: linalg.yield
 // CHECK: }
 // CHECK: return
 // CHECK: }

test/Conversion/StructuredToMemref/convert_1d_elemwise_arith_ternary.mlir

Lines changed: 8 additions & 4 deletions
@@ -31,7 +31,9 @@ module {
 // CHECK-DAG: [[MAP_0_:#.+]] = affine_map<(d0) -> (d0)>
 // CHECK-LABEL: func.func @kernel
 // CHECK-SAME: ([[PARAM_0_:%.+]]: memref<*xi1>, [[PARAM_1_:%.+]]: memref<*xf32>, [[PARAM_2_:%.+]]: memref<*xf32>, [[PARAM_3_:%.+]]: memref<*xf32>, [[PARAM_4_:%.+]]: i32, [[PARAM_5_:%.+]]: i32, [[PARAM_6_:%.+]]: i32, [[PARAM_7_:%.+]]: i32, [[PARAM_8_:%.+]]: i32, [[PARAM_9_:%.+]]: i32) {
-// CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0 : index
+// CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0 : i32
+// CHECK-DAG: [[VAR_empty_offsets_:%.+]] = tensor.empty() : tensor<1024xi32>
+// CHECK-DAG: [[VAR_zero_offsets_:%.+]] = linalg.fill ins([[CST_0_]] : i32) outs([[VAR_empty_offsets_]] : tensor<1024xi32>) -> tensor<1024xi32>
 // CHECK-DAG: [[VAR_reinterpret_cast_:%.+]] = memref.reinterpret_cast [[PARAM_0_]] to offset: [0], sizes: [1024], strides: [1] : memref<*xi1> to memref<1024xi1, strided<[1]>>
 // CHECK-DAG: [[VAR_reinterpret_cast_0_:%.+]] = memref.reinterpret_cast [[PARAM_1_]] to offset: [0], sizes: [1024], strides: [1] : memref<*xf32> to memref<1024xf32, strided<[1]>>
 // CHECK-DAG: [[VAR_reinterpret_cast_1_:%.+]] = memref.reinterpret_cast [[PARAM_2_]] to offset: [0], sizes: [1024], strides: [1] : memref<*xf32> to memref<1024xf32, strided<[1]>>
@@ -50,9 +52,11 @@ module {
 // CHECK: linalg.yield [[VAR_4_]] : f32
 // CHECK: } -> tensor<1024xf32>
 // CHECK: [[VAR_cast_:%.+]] = memref.cast [[PARAM_3_]] : memref<*xf32> to memref<?xf32>
-// CHECK: affine.for [[I_0_:%.+]] = 0 to 1024 {
-// CHECK: [[VAR_extracted_:%.+]] = tensor.extract [[VAR_3_]]{{.}}[[I_0_]]{{.}} : tensor<1024xf32>
-// CHECK: memref.store [[VAR_extracted_]], [[VAR_cast_]]{{.}}[[CST_0_]]{{.}} : memref<?xf32>
+// CHECK: linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins([[VAR_zero_offsets_]], [[VAR_3_]] : tensor<1024xi32>, tensor<1024xf32>) {
+// CHECK: ^bb0([[IN_3_:%.+]]: i32, [[IN_4_:%.+]]: f32):
+// CHECK: [[VAR_5_:%.+]] = arith.index_cast [[IN_3_]] : i32 to index
+// CHECK: memref.store [[IN_4_]], [[VAR_cast_]]{{.}}[[VAR_5_]]{{.}} : memref<?xf32>
+// CHECK: linalg.yield
 // CHECK: }
 // CHECK: return
 // CHECK: }
