Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 13 additions & 5 deletions mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -773,22 +773,30 @@ static DiagnosedSilenceableFailure checkMappingSpec(
std::optional<TransformOpInterface> transformOp, scf::ForallOp forallOp,
ArrayRef<int64_t> numParallelIterations, ArrayRef<int64_t> blockOrGridSizes,
int factor, bool useLinearMapping = false) {
if (llvm::any_of(blockOrGridSizes, [](int64_t i) { return i <= 0; })) {
return definiteFailureHelper(transformOp, forallOp,
"block/grid sizes must be strictly positive");
}
if (!useLinearMapping && blockOrGridSizes.front() % factor != 0) {
auto diag = definiteFailureHelper(
transformOp, forallOp,
Twine("3-D mapping: size of threadIdx.x must be a multiple of ") +
Twine(factor));
return diag;
}
if (computeProduct(numParallelIterations) * factor >
computeProduct(blockOrGridSizes)) {
bool hasZeroParallelIteration =
llvm::any_of(numParallelIterations, [](int64_t i) { return i == 0; });
int64_t required = hasZeroParallelIteration
? 0
: computeProduct(numParallelIterations) * factor;
int64_t available = computeProduct(blockOrGridSizes);
if (required > available) {
auto diag = definiteFailureHelper(
transformOp, forallOp,
Twine("the number of required parallel resources (blocks or "
"threads) ") +
Twine(computeProduct(numParallelIterations) * factor) +
" overflows the number of available resources " +
Twine(computeProduct(blockOrGridSizes)));
Twine(required) + " overflows the number of available resources " +
Twine(available));
return diag;
}
return DiagnosedSilenceableFailure::success();
Expand Down
8 changes: 5 additions & 3 deletions mlir/lib/Dialect/Utils/StaticValueUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -141,12 +141,14 @@ std::optional<int64_t> getConstantIntValue(OpFoldResult ofr) {
/// If all of `ofrs` hold constant integer values (as determined by
/// `getConstantIntValue`), returns those values in order; otherwise returns
/// std::nullopt.
std::optional<SmallVector<int64_t>>
getConstantIntValues(ArrayRef<OpFoldResult> ofrs) {
  SmallVector<int64_t> res;
  res.reserve(ofrs.size());
  for (OpFoldResult ofr : ofrs) {
    std::optional<int64_t> cv = getConstantIntValue(ofr);
    // A single non-constant entry makes the whole query fail; bail out
    // immediately instead of scanning the rest and filling in dummy zeros.
    if (!cv.has_value())
      return std::nullopt;
    res.push_back(*cv);
  }
  return res;
}
Expand Down
30 changes: 30 additions & 0 deletions mlir/test/Dialect/GPU/transform-gpu-failing.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -512,3 +512,33 @@ module attributes {transform.with_named_sequence} {
transform.yield
}
}

// -----

// Negative test: `transform.gpu.map_nested_forall_to_threads` must reject a
// `block_dims` specification that contains a zero dimension, emitting the
// "block/grid sizes must be strictly positive" diagnostic checked below.
// The payload itself is well-formed (a gpu.launch wrapping a 2-D scf.forall
// mapped to thread ids), so the only expected error comes from the transform.
func.func @map_nested_forall_to_threads_invalid_block_dims(%x: memref<2 x 32 x f32>, %y: memref<2 x 32 x f32>, %t: memref<32 x f32>, %alpha : f32, %stream : !gpu.async.token) -> memref<2 x 32 x f32> {
%one = arith.constant 1 : index
%c900 = arith.constant 900 : index
%c7 = arith.constant 7 : index
%name = gpu.launch async[%stream] blocks(%arg3, %arg4, %arg5) in (%arg9 = %one, %arg10 = %one, %arg11 = %one)
threads(%arg6, %arg7, %arg8) in (%arg12 = %one, %arg13 = %one, %arg14 = %one)
{
// 7x900 parallel iterations mapped onto (thread<y>, thread<x>).
scf.forall (%i, %j) in (%c7, %c900) {
%4 = memref.load %x[%i, %j] : memref<2 x 32 x f32>
%5 = memref.load %y[%i, %j] : memref<2 x 32 x f32>
%6 = math.fma %alpha, %4, %5 : f32
memref.store %6, %y[%i, %j] : memref<2 x 32 x f32>
} { mapping = [#gpu.thread<y>, #gpu.thread<x>] }
gpu.terminator
}
return %y : memref<2 x 32 x f32>
}

// Transform script: match the gpu.launch and attempt the mapping with an
// invalid block_dims (the middle dimension is 0).
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
%funcop = transform.structured.match ops{["gpu.launch"]} in %arg0 : (!transform.any_op) -> !transform.any_op
// expected-error @below {{block/grid sizes must be strictly positive}}
transform.gpu.map_nested_forall_to_threads %funcop block_dims = [128, 0, 1] : (!transform.any_op) -> !transform.any_op
transform.yield
}
}

Loading