Skip to content

Commit 4709018

Browse files
authored
[MLIR][Spirv] Don't lower tensors that can't be represented by an ArrayType (#171002)
I noticed this because of llvm/llvm-project#159738, though it was only caught by his fuzzer because it wrapped to 0. Also, is there a reason for the usage of `unsigned` for sizes in spirv types? I believe most of the builtin types use `int64_t` for sizes, so it may make sense to do the same for spirv.
1 parent 371da58 commit 4709018

File tree

2 files changed

+14
-0
lines changed

2 files changed

+14
-0
lines changed

mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -502,6 +502,11 @@ static Type convertTensorType(const spirv::TargetEnv &targetEnv,
502502
<< type << " illegal: cannot handle zero-element tensors\n");
503503
return nullptr;
504504
}
505+
if (arrayElemCount > std::numeric_limits<unsigned>::max()) {
506+
LLVM_DEBUG(llvm::dbgs()
507+
<< type << " illegal: cannot fit tensor into target type\n");
508+
return nullptr;
509+
}
505510

506511
Type arrayElemType = convertScalarType(targetEnv, options, scalarType);
507512
if (!arrayElemType)

mlir/test/Conversion/TensorToSPIRV/tensor-ops-to-spirv.mlir

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,3 +79,12 @@ func.func @tensor_2d_empty() -> () {
7979
%x = arith.constant dense<> : tensor<2x0xi32>
8080
return
8181
}
82+
83+
// Tensors with more than UINT32_MAX elements cannnot fit in a spirv.array.
84+
// Test that they are not lowered.
85+
// CHECK-LABEL: func @very_large_tensor
86+
// CHECK-NEXT: arith.constant dense<1>
87+
func.func @very_large_tensor() -> () {
88+
%x = arith.constant dense<1> : tensor<4294967296xi32>
89+
return
90+
}

0 commit comments

Comments
 (0)