Skip to content

Commit 3d7e854

Browse files
authored
Merge pull request #2044 from ROCm/justinr-sroa-7.1
[7.1][EXTERNAL][SROA] Add Stored Value Size Check for Tree-Structured Merge
2 parents 33b0fc5 + f93de3d commit 3d7e854

File tree

4 files changed

+41
-0
lines changed

4 files changed

+41
-0
lines changed

external/llvm-project/llvm/lib/Transforms/Scalar/SROA.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2961,6 +2961,7 @@ class AllocaSliceRewriter : public InstVisitor<AllocaSliceRewriter, bool> {
29612961
isa<FixedVectorType>(NewAI.getAllocatedType())
29622962
? cast<FixedVectorType>(NewAI.getAllocatedType())->getElementType()
29632963
: Type::getInt8Ty(NewAI.getContext());
2964+
unsigned AllocatedEltTySize = DL.getTypeSizeInBits(AllocatedEltTy);
29642965

29652966
// Helper to check if a type is
29662967
// 1. A fixed vector type
@@ -2991,10 +2992,17 @@ class AllocaSliceRewriter : public InstVisitor<AllocaSliceRewriter, bool> {
29912992
// Do not handle the case if
29922993
// 1. The store does not meet the conditions in the helper function
29932994
// 2. The store is volatile
2995+
// 3. The total store size is not a multiple of the allocated element
2996+
// type size
29942997
if (!IsTypeValidForTreeStructuredMerge(
29952998
SI->getValueOperand()->getType()) ||
29962999
SI->isVolatile())
29973000
return std::nullopt;
3001+
auto *VecTy = cast<FixedVectorType>(SI->getValueOperand()->getType());
3002+
unsigned NumElts = VecTy->getNumElements();
3003+
unsigned EltSize = DL.getTypeSizeInBits(VecTy->getElementType());
3004+
if (NumElts * EltSize % AllocatedEltTySize != 0)
3005+
return std::nullopt;
29983006
StoreInfos.emplace_back(SI, S.beginOffset(), S.endOffset(),
29993007
SI->getValueOperand());
30003008
} else {

external/llvm-project/llvm/test/Transforms/SROA/vector-promotion-cannot-tree-structure-merge.ll

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -219,4 +219,18 @@ entry:
219219

220220
}
221221

222+
define <1 x i32> @test_store_value_size_not_multiple_of_allocated_element_type_size(<1 x i16> %a, <1 x i16> %b) {
223+
entry:
224+
%alloca = alloca [2 x i16]
225+
226+
%ptr0 = getelementptr inbounds [2 x i16], ptr %alloca, i32 0, i32 0
227+
store <1 x i16> %a, ptr %ptr0
228+
229+
%ptr1 = getelementptr inbounds [2 x i16], ptr %alloca, i32 0, i32 1
230+
store <1 x i16> %b, ptr %ptr1
231+
232+
%result = load <1 x i32>, ptr %alloca
233+
ret <1 x i32> %result
234+
}
235+
222236
declare void @llvm.memset.p0.i64(ptr nocapture writeonly, i8, i64, i1 immarg)
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,6 @@
11
if not config.enable_rock_driver_pr_e2e_test or config.no_AMD_GPU:
22
config.unsupported = True
3+
4+
if not config.arch.startswith("gfx12"):
5+
config.excludes = ['mixr-tadd-tadd-quant-dot.mlir']
6+
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
// RUN: sed s/##TOKEN_ARCH##/%arch/g %s | rocmlir-driver -kernel-pipeline migraphx,highlevel | rocmlir-gen -ph -print-results -rand none - | rocmlir-driver -arch %arch -c | xmir-runner --shared-libs=%linalg_test_lib_dir/libmlir_rocm_runtime%shlibext,%conv_validation_wrapper_library_dir/libconv-validation-wrappers%shlibext,%linalg_test_lib_dir/libmlir_runner_utils%shlibext,%linalg_test_lib_dir/libmlir_float16_utils%shlibext,%linalg_test_lib_dir/libmlir_c_runner_utils%shlibext --entry-point-result=void | FileCheck %s
2+
3+
module {
4+
// CHECK: [28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28]
5+
func.func @mlir_add(%arg0: !migraphx.shaped<3x2x7x2xf8E4M3FN, 28x14x2x1>,
6+
%arg1: !migraphx.shaped<3x2x5x7xf8E4M3FN, 70x35x7x1>) -> !migraphx.shaped<3x2x2x5xf32, 20x10x5x1> attributes {arch = "##TOKEN_ARCH##", kernel = "mixr"} {
7+
%0 = migraphx.transpose %arg0 {permutation = [0, 1, 3, 2]} : <3x2x7x2xf8E4M3FN, 28x14x2x1> -> <3x2x2x7xf8E4M3FN, 28x14x1x2>
8+
%1 = migraphx.add %0, %0 : <3x2x2x7xf8E4M3FN, 28x14x1x2>, <3x2x2x7xf8E4M3FN, 28x14x1x2> -> <3x2x2x7xf8E4M3FN, 28x14x1x2>
9+
%2 = migraphx.transpose %arg1 {permutation = [0, 1, 3, 2]} : <3x2x5x7xf8E4M3FN, 70x35x7x1> -> <3x2x7x5xf8E4M3FN, 70x35x1x7>
10+
%3 = migraphx.add %2, %2 : <3x2x7x5xf8E4M3FN, 70x35x1x7>, <3x2x7x5xf8E4M3FN, 70x35x1x7> -> <3x2x7x5xf8E4M3FN, 70x35x1x7>
11+
%4 = migraphx.quant_dot %1, %3 {perf_config="v3:64,32,4,32,16,16,1,1,2,1,1"} : <3x2x2x7xf8E4M3FN, 28x14x1x2>, <3x2x7x5xf8E4M3FN, 70x35x1x7> -> <3x2x2x5xf32, 20x10x5x1>
12+
return %4 : !migraphx.shaped<3x2x2x5xf32, 20x10x5x1>
13+
}
14+
}
15+

0 commit comments

Comments
 (0)