Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
198 changes: 106 additions & 92 deletions lib/Conversion/UnstructuredToMemref/UnstructuredToMemrefPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,7 @@ class PtrToUnrankedMemrefConverter : public TypeConverter {
});
addTargetMaterialization([&](OpBuilder &builder,
UnrankedMemRefType resultType,
ValueRange inputs,
Location loc) -> Value {
ValueRange inputs, Location loc) -> Value {
return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs)
.getResult(0);
});
Expand Down Expand Up @@ -158,7 +157,7 @@ struct ScalarStoreConverter : public OpConversionPattern<tts::ScatterOp> {
}
};

// Lowering an unstructured load op (gather) into a linalg.generic op
// Lowering an unstructured load op (gather) into a linalg.generic op.
struct GatherConverter : public OpConversionPattern<tts::GatherOp> {
using OpConversionPattern<tts::GatherOp>::OpConversionPattern;

Expand All @@ -171,118 +170,109 @@ struct GatherConverter : public OpConversionPattern<tts::GatherOp> {
LogicalResult
matchAndRewrite(tts::GatherOp gatherOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {

auto loc = gatherOp->getLoc();

auto ptr = adaptor.getPtr();
auto offsetTensor = adaptor.getOffset();
auto offsetType = dyn_cast<ShapedType>(offsetTensor.getType());

// This must be a scalar load, skip processing
// This must be a scalar load, skip processing.
if (!offsetType) {
return failure();
}

auto loadResultType =
auto resultType =
dyn_cast<RankedTensorType>(gatherOp.getResult().getType());

// Treat the base pointer (memref) as 1D because the offsets are all
// relative to a single base pointer (already collapsed).
auto baseMemref = rewriter.create<memref::CastOp>(
loc,
MemRefType::get({ShapedType::kDynamic},
loadResultType.getElementType()),
ptr);
auto baseMemref = rewriter
.create<memref::CastOp>(
loc,
MemRefType::get({ShapedType::kDynamic},
resultType.getElementType()),
ptr)
.getResult();

auto baseTensor =
rewriter
.create<bufferization::ToTensorOp>(
loc,
RankedTensorType::get(
SmallVector<int64_t>(1, ShapedType::kDynamic),
loadResultType.getElementType()),
resultType.getElementType()),
baseMemref, true /* restrict */, false /* writable */)
.getResult();

// The linalg.generic op should have the following inputs:
// - the offset tensor
// - an optional mask tensor if the load op contains mask
// - the offset tensor.
// - an optional mask tensor if the gather op contains mask.
SmallVector<Value> inputs{offsetTensor};

if (gatherOp.getMask()) {
inputs.push_back(gatherOp.getMask());
}

auto emptyTensor =
rewriter
.create<tensor::EmptyOp>(loc, loadResultType.getShape(),
loadResultType.getElementType())
.getResult();
auto emptyTensor = rewriter
.create<tensor::EmptyOp>(loc, resultType.getShape(),
resultType.getElementType())
.getResult();

// Affine maps for the inputs and output
// If no mask is used, 2 affine maps are generated; one for the input offset
// tensor, the other for the output tensor.
// If mask is used, the first 2 maps are for the offset and mask tensors
// while the last map is for the output tensor.
// Affine maps for the inputs and one additional output.
SmallVector<AffineMap> affineMaps(
gatherOp.getMask() ? 3 : 2,
rewriter.getMultiDimIdentityMap(loadResultType.getRank()));
inputs.size() + 1,
rewriter.getMultiDimIdentityMap(resultType.getRank()));

// All iterator types are parallel.
SmallVector<utils::IteratorType> iteratorTypes(
resultType.getRank(), utils::IteratorType::parallel);

auto genericOp = rewriter.create<linalg::GenericOp>(
loc, SmallVector<Type>({loadResultType}), inputs,
ValueRange{emptyTensor}, affineMaps,
SmallVector<utils::IteratorType>(loadResultType.getRank(),
utils::IteratorType::parallel),
[&](OpBuilder &b, Location loc, ValueRange args) {
auto getValueAtIndex = [baseTensor](Value indexValue, Location loc,
OpBuilder &b) -> Value {
loc, TypeRange{resultType}, inputs, ValueRange{emptyTensor}, affineMaps,
iteratorTypes, [&](OpBuilder &b, Location loc, ValueRange args) {
auto getValueAtIndex = [baseTensor](OpBuilder &b, Location loc,
Value index) -> Value {
Value index0 =
b.create<arith::IndexCastOp>(loc, b.getIndexType(), indexValue);
b.create<arith::IndexCastOp>(loc, b.getIndexType(), index);

return b.create<tensor::ExtractOp>(loc, baseTensor,
ValueRange{index0});
};

auto offset = args[0];

if (!gatherOp.getMask()) {
// If there is no mask, simply extract the current element from the
// base tensor and use it as the yield value.
auto loadValue = getValueAtIndex(args[0], loc, rewriter);
rewriter.create<linalg::YieldOp>(loc, loadValue);
auto loadValue = getValueAtIndex(b, loc, offset);
b.create<linalg::YieldOp>(loc, loadValue);
} else {
// If the mask value is truthy, the current element is loaded from
// the base tensor using its offset. Otherwise, if `other` is
// present, yield `other`. If `other` is not present, a default
// value of 0 is used.
auto mask = args[1];
auto ifOp = rewriter.create<scf::IfOp>(
auto ifOp = b.create<scf::IfOp>(
loc, mask,
[&](OpBuilder &b, Location loc) {
// Truthy case, load from the index
auto loadValue = getValueAtIndex(args[0], loc, b);
b.create<scf::YieldOp>(loc, loadValue);
// Truthy case, load from the index.
auto value = getValueAtIndex(b, loc, offset);
b.create<scf::YieldOp>(loc, value);
},
[&](OpBuilder &b, Location loc) {
// Falsy case, yield `other` or 0 as the default value
// Falsy case, yield `other` or 0 as the default value.
if (gatherOp.getOther()) {
b.create<scf::YieldOp>(loc, gatherOp.getOther());
} else {
auto elemType = baseTensor.getType().getElementType();
Value extract;
if (isa<IntegerType>(elemType)) {
extract = rewriter.create<arith::ConstantOp>(
loc, b.getIntegerAttr(elemType, 0));
} else if (isa<FloatType>(elemType)) {
extract = rewriter.create<arith::ConstantOp>(
loc, b.getFloatAttr(elemType, 0));
} else {
elemType.dump();
llvm_unreachable("unexpected type");
}
auto elemType = resultType.getElementType();
auto zeroAttr = b.getZeroAttr(elemType);
assert(zeroAttr && "unexpected element type");
Value extract = b.create<arith::ConstantOp>(loc, zeroAttr);
b.create<scf::YieldOp>(loc, extract);
}
});

rewriter.create<linalg::YieldOp>(loc, ifOp->getResult(0));
b.create<linalg::YieldOp>(loc, ifOp->getResult(0));
}
});

Expand All @@ -292,7 +282,7 @@ struct GatherConverter : public OpConversionPattern<tts::GatherOp> {
}
};

// Lowering an unstructured store op (scatter) into an affine loop nest
// Lowering an unstructured store op (scatter) into a linalg.generic op.
struct ScatterConverter : public OpConversionPattern<tts::ScatterOp> {
using OpConversionPattern<tts::ScatterOp>::OpConversionPattern;

Expand All @@ -309,57 +299,81 @@ struct ScatterConverter : public OpConversionPattern<tts::ScatterOp> {

auto ptr = adaptor.getPtr();
auto offsetTensor = adaptor.getOffset();
auto valueTensor = adaptor.getValue();
auto offsetType = dyn_cast<ShapedType>(offsetTensor.getType());

// This must be a scalar store, skip processing
// This must be a scalar store, skip processing.
if (!offsetType) {
return failure();
}

auto resultType =
dyn_cast<RankedTensorType>(scatterOp.getValue().getType());
auto valueType = dyn_cast<RankedTensorType>(scatterOp.getValue().getType());

auto storeMemref = rewriter.create<memref::CastOp>(
loc,
MemRefType::get({ShapedType::kDynamic}, resultType.getElementType()),
ptr);

auto ip = rewriter.saveInsertionPoint();

SmallVector<Value> ivs;
for (auto dim : resultType.getShape()) {
auto ub =
rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(dim));
auto forOp = rewriter.create<affine::AffineForOp>(loc, 0, dim);
ivs.push_back(forOp.getInductionVar());
rewriter.setInsertionPointToStart(forOp.getBody());
}
// Treat the base pointer (memref) as 1D because the offsets are all
// relative to a single base pointer (already collapsed).
auto baseMemref =
rewriter
.create<memref::CastOp>(loc,
MemRefType::get({ShapedType::kDynamic},
valueType.getElementType()),
ptr)
.getResult();

// The linalg.generic op should have the following inputs:
// - the offset tensor.
// - the value tensor.
// - an optional mask tensor if the scatter op contains mask.
SmallVector<Value> inputs{offsetTensor, valueTensor};

if (scatterOp.getMask()) {
// Mask case, only store the value if the mask value at `ivs` is truthy
auto maskValue =
rewriter.create<tensor::ExtractOp>(loc, scatterOp.getMask(), ivs);
inputs.push_back(scatterOp.getMask());
}

auto ifOp = rewriter.create<scf::IfOp>(loc, maskValue,
false /* withElseRegion */);
// Affine maps for the inputs.
SmallVector<AffineMap> affineMaps(
inputs.size(), rewriter.getMultiDimIdentityMap(valueType.getRank()));

rewriter.setInsertionPointToStart(
&ifOp.getThenRegion().getBlocks().front());
}
// All iterator types are parallel.
SmallVector<utils::IteratorType> iteratorTypes(
valueType.getRank(), utils::IteratorType::parallel);

// Generate ops to store the value at each index. Note that with masking,
// these ops are created in the `if` block generated above.
auto offsetValue =
rewriter.create<tensor::ExtractOp>(loc, offsetTensor, ivs);
auto storeValue =
rewriter.create<tensor::ExtractOp>(loc, scatterOp.getValue(), ivs);
Value storeIndex = rewriter.create<arith::IndexCastOp>(
loc, rewriter.getIndexType(), offsetValue);
rewriter.create<memref::StoreOp>(loc, storeValue, storeMemref, storeIndex);
rewriter.setInsertionPoint(scatterOp);

auto genericOp = rewriter.create<linalg::GenericOp>(
loc, TypeRange{}, inputs, ValueRange{}, affineMaps, iteratorTypes,
[&](OpBuilder &b, Location loc, ValueRange args) {
auto storeValueAtIndex = [baseMemref](OpBuilder &b, Location loc,
Value index, Value value) {
Value index0 =
b.create<arith::IndexCastOp>(loc, b.getIndexType(), index);

b.create<memref::StoreOp>(loc, value, baseMemref,
ValueRange{index0});
};

auto offset = args[0];
auto value = args[1];

if (!scatterOp.getMask()) {
// If there is no mask, simply insert the current value to the
// base memref using its offset.
storeValueAtIndex(b, loc, offset, value);
} else {
// If the mask value is truthy, insert the current value to the
// the base memref using its offset. Otherwise, noop.
auto mask = args[2];
auto ifOp =
b.create<scf::IfOp>(loc, mask, [&](OpBuilder &b, Location loc) {
storeValueAtIndex(b, loc, offset, value);
b.create<scf::YieldOp>(loc);
});
}

b.create<linalg::YieldOp>(loc);
});

// Finalize
rewriter.eraseOp(scatterOp);
rewriter.restoreInsertionPoint(ip);

return success();
}
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,9 @@ module {
// CHECK-DAG: [[MAP_0_:#.+]] = affine_map<(d0) -> (d0)>
// CHECK-LABEL: func.func @kernel
// CHECK-SAME: ([[PARAM_0_:%.+]]: memref<*xf32>, [[PARAM_1_:%.+]]: memref<*xf32>, [[PARAM_2_:%.+]]: memref<*xf32>, [[PARAM_3_:%.+]]: i32, [[PARAM_4_:%.+]]: i32, [[PARAM_5_:%.+]]: i32, [[PARAM_6_:%.+]]: i32, [[PARAM_7_:%.+]]: i32, [[PARAM_8_:%.+]]: i32) {
// CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0 : index
// CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0 : i32
// CHECK-DAG: [[VAR_empty_offsets_:%.+]] = tensor.empty() : tensor<1024xi32>
// CHECK-DAG: [[VAR_zero_offsets_:%.+]] = linalg.fill ins([[CST_0_]] : i32) outs([[VAR_empty_offsets_]] : tensor<1024xi32>) -> tensor<1024xi32>
// CHECK-DAG: [[VAR_reinterpret_cast_:%.+]] = memref.reinterpret_cast [[PARAM_0_]] to offset: [0], sizes: [1024], strides: [1] : memref<*xf32> to memref<1024xf32, strided<[1]>>
// CHECK-DAG: [[VAR_reinterpret_cast_0_:%.+]] = memref.reinterpret_cast [[PARAM_1_]] to offset: [0], sizes: [1024], strides: [1] : memref<*xf32> to memref<1024xf32, strided<[1]>>
// CHECK-DAG: [[RES_:%.+]] = memref.alloc() : memref<1024xf32>
Expand Down Expand Up @@ -73,9 +75,11 @@ module {
// CHECK: linalg.yield [[VAR_9_5_]] : f32
// CHECK: } -> tensor<1024xf32>
// CHECK: [[VAR_cast_:%.+]] = memref.cast [[PARAM_2_]] : memref<*xf32> to memref<?xf32>
// CHECK: affine.for [[I_0_:%.+]] = 0 to 1024 {
// CHECK: [[VAR_extracted_:%.+]] = tensor.extract [[VAR_8_]]{{.}}[[I_0_]]{{.}} : tensor<1024xf32>
// CHECK: memref.store [[VAR_extracted_]], [[VAR_cast_]]{{.}}[[CST_0_]]{{.}} : memref<?xf32>
// CHECK: linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins([[VAR_zero_offsets_]], [[VAR_8_]] : tensor<1024xi32>, tensor<1024xf32>) {
// CHECK: ^bb0([[IN_18_:%.+]]: i32, [[IN_19_:%.+]]: f32):
// CHECK: [[VAR_10_:%.+]] = arith.index_cast [[IN_18_]] : i32 to index
// CHECK: memref.store [[IN_19_]], [[VAR_cast_]]{{.}}[[VAR_10_]]{{.}} : memref<?xf32>
// CHECK: linalg.yield
// CHECK: }
// CHECK: return
// CHECK: }
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,9 @@ module {
// CHECK-DAG: [[MAP_0_:#.+]] = affine_map<(d0) -> (d0)>
// CHECK-LABEL: func.func @kernel
// CHECK-SAME: ([[PARAM_0_:%.+]]: memref<*xi1>, [[PARAM_1_:%.+]]: memref<*xf32>, [[PARAM_2_:%.+]]: memref<*xf32>, [[PARAM_3_:%.+]]: memref<*xf32>, [[PARAM_4_:%.+]]: i32, [[PARAM_5_:%.+]]: i32, [[PARAM_6_:%.+]]: i32, [[PARAM_7_:%.+]]: i32, [[PARAM_8_:%.+]]: i32, [[PARAM_9_:%.+]]: i32) {
// CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0 : index
// CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0 : i32
// CHECK-DAG: [[VAR_empty_offsets_:%.+]] = tensor.empty() : tensor<1024xi32>
// CHECK-DAG: [[VAR_zero_offsets_:%.+]] = linalg.fill ins([[CST_0_]] : i32) outs([[VAR_empty_offsets_]] : tensor<1024xi32>) -> tensor<1024xi32>
// CHECK-DAG: [[VAR_reinterpret_cast_:%.+]] = memref.reinterpret_cast [[PARAM_0_]] to offset: [0], sizes: [1024], strides: [1] : memref<*xi1> to memref<1024xi1, strided<[1]>>
// CHECK-DAG: [[VAR_reinterpret_cast_0_:%.+]] = memref.reinterpret_cast [[PARAM_1_]] to offset: [0], sizes: [1024], strides: [1] : memref<*xf32> to memref<1024xf32, strided<[1]>>
// CHECK-DAG: [[VAR_reinterpret_cast_1_:%.+]] = memref.reinterpret_cast [[PARAM_2_]] to offset: [0], sizes: [1024], strides: [1] : memref<*xf32> to memref<1024xf32, strided<[1]>>
Expand All @@ -50,9 +52,11 @@ module {
// CHECK: linalg.yield [[VAR_4_]] : f32
// CHECK: } -> tensor<1024xf32>
// CHECK: [[VAR_cast_:%.+]] = memref.cast [[PARAM_3_]] : memref<*xf32> to memref<?xf32>
// CHECK: affine.for [[I_0_:%.+]] = 0 to 1024 {
// CHECK: [[VAR_extracted_:%.+]] = tensor.extract [[VAR_3_]]{{.}}[[I_0_]]{{.}} : tensor<1024xf32>
// CHECK: memref.store [[VAR_extracted_]], [[VAR_cast_]]{{.}}[[CST_0_]]{{.}} : memref<?xf32>
// CHECK: linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins([[VAR_zero_offsets_]], [[VAR_3_]] : tensor<1024xi32>, tensor<1024xf32>) {
// CHECK: ^bb0([[IN_3_:%.+]]: i32, [[IN_4_:%.+]]: f32):
// CHECK: [[VAR_5_:%.+]] = arith.index_cast [[IN_3_]] : i32 to index
// CHECK: memref.store [[IN_4_]], [[VAR_cast_]]{{.}}[[VAR_5_]]{{.}} : memref<?xf32>
// CHECK: linalg.yield
// CHECK: }
// CHECK: return
// CHECK: }
Loading