Skip to content

Commit df71000

Browse files
committed
[mlir][spirv] Convert linalg.generic for reduction to SPIR-V ops
This commit adds a pattern to lower linalg.generic for reduction to spv.GroupNonUniform* ops. Right now this only supports integer reduction on 1-D input memref. Shader entry point ABI is queried to make sure that the input memref's shape matches the local workgroup's invocation configuration. This makes sure that the workload fits in one local workgroup so that we can leverage SPIR-V group non-uniform operations. linglg.generic is a structured op that preserves the right level of information. It is easier to recognize reduction at this level than performing analysis on loops. This commit also exposes `getElementPtr` in SPIRVLowering.h given that it's a generally useful utility function. Differential Revision: https://reviews.llvm.org/D73437
1 parent 0bb60e2 commit df71000

File tree

20 files changed

+682
-54
lines changed

20 files changed

+682
-54
lines changed
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
//===- LinalgToSPIRV.h - Linalg to SPIR-V dialect conversion ----*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// This file provides patterns for Linalg to SPIR-V dialect conversion.
10+
//
11+
//===----------------------------------------------------------------------===//
12+
13+
#ifndef MLIR_CONVERSION_LINALGTOSPIRV_LINALGTOSPIRV_H
14+
#define MLIR_CONVERSION_LINALGTOSPIRV_LINALGTOSPIRV_H
15+
16+
namespace mlir {
17+
class MLIRContext;
18+
class OwningRewritePatternList;
19+
class SPIRVTypeConverter;
20+
21+
/// Appends to a pattern list additional patterns for translating Linalg ops to
22+
/// SPIR-V ops.
23+
void populateLinalgToSPIRVPatterns(MLIRContext *context,
24+
SPIRVTypeConverter &typeConverter,
25+
OwningRewritePatternList &patterns);
26+
27+
} // namespace mlir
28+
29+
#endif // MLIR_CONVERSION_LINALGTOSPIRV_LINALGTOSPIRV_H
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
//===- LinalgToSPIRVPass.h - Linalg to SPIR-V conversion pass --*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// This file provides a pass for Linalg to SPIR-V dialect conversion.
10+
//
11+
//===----------------------------------------------------------------------===//
12+
13+
#ifndef MLIR_CONVERSION_STANDARDTOSPIRV_LINALGTOSPIRVPASS_H
14+
#define MLIR_CONVERSION_STANDARDTOSPIRV_LINALGTOSPIRVPASS_H
15+
16+
#include "mlir/Pass/Pass.h"
17+
18+
namespace mlir {
19+
20+
/// Creates and returns a pass to convert Linalg ops to SPIR-V ops.
21+
std::unique_ptr<OpPassBase<ModuleOp>> createLinalgToSPIRVPass();
22+
23+
} // namespace mlir
24+
25+
#endif // MLIR_CONVERSION_STANDARDTOSPIRV_LINALGTOSPIRVPASS_H

mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,11 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
5454
"Query the number of loops within the current operation.",
5555
"unsigned", "getNumLoops">,
5656

57+
InterfaceMethod<
58+
[{Returns true if the current operation has only one loop and it's a
59+
reduction loop}],
60+
"unsigned", "hasSingleReductionLoop">,
61+
5762
//========================================================================//
5863
// Input arguments handling.
5964
//========================================================================//

mlir/include/mlir/Dialect/Linalg/IR/LinalgTraits.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,12 @@ class StructuredOpTraits
9393
cast<ConcreteType>(this->getOperation()).iterator_types());
9494
}
9595

96+
bool hasSingleReductionLoop() {
97+
auto iterators = cast<ConcreteType>(this->getOperation()).iterator_types();
98+
return iterators.size() == 1 &&
99+
getNumIterators(getReductionIteratorTypeName(), iterators);
100+
}
101+
96102
//==========================================================================//
97103
// Input arguments handling.
98104
//==========================================================================//

mlir/include/mlir/Dialect/Linalg/Utils/Utils.h

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,26 @@ struct FusionInfo {
2828
LinalgOp fusedProducer;
2929
};
3030

31+
/// A struct containing common matchers over linalg op's region.
32+
struct RegionMatcher {
33+
enum class BinaryOpKind {
34+
IAdd,
35+
};
36+
37+
/// Matches the given linalg op if its body is performing binary operation on
38+
/// int or float scalar values and returns the binary op kind.
39+
///
40+
/// The linalg op's region is expected to be
41+
/// ```
42+
/// {
43+
/// ^bb(%a: <scalar-type>, %b: <scalar-type>):
44+
/// %0 = <binary-op> %a, %b: <scalar-type>
45+
/// linalg.yield %0: <scalar-type>
46+
/// }
47+
/// ```
48+
static Optional<BinaryOpKind> matchAsScalarBinaryOp(GenericOp op);
49+
};
50+
3151
/// Checks whether the specific `producer` is the last write to exactly the
3252
/// whole `consumedView`. This checks structural dominance, that the dependence
3353
/// is a RAW without any interleaved write to any piece of `consumedView`.

mlir/include/mlir/Dialect/SPIRV/SPIRVAtomicOps.td

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ class SPV_AtomicUpdateOp<string mnemonic, list<OpTrait> traits = []> :
2525
SPV_ScopeAttr:$memory_scope,
2626
SPV_MemorySemanticsAttr:$semantics
2727
);
28+
2829
let results = (outs
2930
SPV_Integer:$result
3031
);
@@ -42,9 +43,19 @@ class SPV_AtomicUpdateWithValueOp<string mnemonic, list<OpTrait> traits = []> :
4243
SPV_MemorySemanticsAttr:$semantics,
4344
SPV_Integer:$value
4445
);
46+
4547
let results = (outs
4648
SPV_Integer:$result
4749
);
50+
51+
let builders = [
52+
OpBuilder<
53+
[{Builder *builder, OperationState &state, Value pointer,
54+
::mlir::spirv::Scope scope, ::mlir::spirv::MemorySemantics memory,
55+
Value value}],
56+
[{build(builder, state, value.getType(), pointer, scope, memory, value);}]
57+
>
58+
];
4859
}
4960

5061
// -----

mlir/include/mlir/Dialect/SPIRV/SPIRVControlFlowOps.td

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -453,14 +453,22 @@ def SPV_SelectionOp : SPV_Op<"selection", [InFunctionScope]> {
453453
let regions = (region AnyRegion:$body);
454454

455455
let extraClassDeclaration = [{
456-
// Returns the selection header block.
456+
/// Returns the selection header block.
457457
Block *getHeaderBlock();
458458

459-
// Returns the selection merge block.
459+
/// Returns the selection merge block.
460460
Block *getMergeBlock();
461461

462-
// Adds a selection merge block containing one spv._merge op.
462+
/// Adds a selection merge block containing one spv._merge op.
463463
void addMergeBlock();
464+
465+
/// Creates a spv.selection op for `if (<condition>) then { <thenBody> }`
466+
/// with `builder`. `builder`'s insertion point will remain at after the
467+
/// newly inserted spv.selection op afterwards.
468+
static SelectionOp createIfThen(
469+
Location loc, Value condition,
470+
function_ref<void(OpBuilder *builder)> thenBody,
471+
OpBuilder *builder);
464472
}];
465473

466474
let hasOpcode = 0;

mlir/include/mlir/Dialect/SPIRV/SPIRVLowering.h

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,8 @@ void populateBuiltinFuncToSPIRVPatterns(MLIRContext *context,
5858
OwningRewritePatternList &patterns);
5959

6060
namespace spirv {
61+
class AccessChainOp;
62+
6163
class SPIRVConversionTarget : public ConversionTarget {
6264
public:
6365
/// Creates a SPIR-V conversion target for the given target environment.
@@ -90,6 +92,16 @@ class SPIRVConversionTarget : public ConversionTarget {
9092
Value getBuiltinVariableValue(Operation *op, BuiltIn builtin,
9193
OpBuilder &builder);
9294

95+
/// Performs the index computation to get to the element at `indices` of the
96+
/// memory pointed to by `basePtr`, using the layout map of `baseType`.
97+
98+
// TODO(ravishankarm) : This method assumes that the `baseType` is a MemRefType
99+
// with AffineMap that has static strides. Extend to handle dynamic strides.
100+
spirv::AccessChainOp getElementPtr(SPIRVTypeConverter &typeConverter,
101+
MemRefType baseType, Value basePtr,
102+
ArrayRef<Value> indices, Location loc,
103+
OpBuilder &builder);
104+
93105
/// Sets the InterfaceVarABIAttr and EntryPointABIAttr for a function and its
94106
/// arguments.
95107
LogicalResult setABIAttrs(FuncOp funcOp, EntryPointABIAttr entryPointInfo,

mlir/include/mlir/Dialect/SPIRV/TargetAndABI.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,12 @@ TargetEnvAttr getDefaultTargetEnv(MLIRContext *context);
5454
/// target environment (SPIR-V 1.0 with Shader capability and no extra
5555
/// extensions) if not provided.
5656
TargetEnvAttr lookupTargetEnvOrDefault(Operation *op);
57+
58+
/// Queries the local workgroup size from entry point ABI on the nearest
59+
/// function-like op containing the given `op`. Returns null attribute if not
60+
/// found.
61+
DenseIntElementsAttr lookupLocalWorkGroupSize(Operation *op);
62+
5763
} // namespace spirv
5864
} // namespace mlir
5965

mlir/lib/Conversion/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ add_subdirectory(GPUToNVVM)
44
add_subdirectory(GPUToROCDL)
55
add_subdirectory(GPUToSPIRV)
66
add_subdirectory(LinalgToLLVM)
7+
add_subdirectory(LinalgToSPIRV)
78
add_subdirectory(LoopsToGPU)
89
add_subdirectory(LoopToStandard)
910
add_subdirectory(StandardToLLVM)

0 commit comments

Comments
 (0)