Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions external/llvm-project/llvm/lib/Transforms/Scalar/SROA.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2961,6 +2961,7 @@ class AllocaSliceRewriter : public InstVisitor<AllocaSliceRewriter, bool> {
isa<FixedVectorType>(NewAI.getAllocatedType())
? cast<FixedVectorType>(NewAI.getAllocatedType())->getElementType()
: Type::getInt8Ty(NewAI.getContext());
unsigned AllocatedEltTySize = DL.getTypeSizeInBits(AllocatedEltTy);

// Helper to check if a type is
// 1. A fixed vector type
Expand Down Expand Up @@ -2991,10 +2992,17 @@ class AllocaSliceRewriter : public InstVisitor<AllocaSliceRewriter, bool> {
// Do not handle the case if
// 1. The store does not meet the conditions in the helper function
// 2. The store is volatile
// 3. The total store size is not a multiple of the allocated element
// type size
if (!IsTypeValidForTreeStructuredMerge(
SI->getValueOperand()->getType()) ||
SI->isVolatile())
return std::nullopt;
auto *VecTy = cast<FixedVectorType>(SI->getValueOperand()->getType());
unsigned NumElts = VecTy->getNumElements();
unsigned EltSize = DL.getTypeSizeInBits(VecTy->getElementType());
if (NumElts * EltSize % AllocatedEltTySize != 0)
return std::nullopt;
StoreInfos.emplace_back(SI, S.beginOffset(), S.endOffset(),
SI->getValueOperand());
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -219,4 +219,18 @@ entry:

}

define <1 x i32> @test_store_value_size_not_multiple_of_allocated_element_type_size(<1 x i16> %a, <1 x i16> %b) {
entry:
%alloca = alloca [2 x i16]

%ptr0 = getelementptr inbounds [2 x i16], ptr %alloca, i32 0, i32 0
store <1 x i16> %a, ptr %ptr0

%ptr1 = getelementptr inbounds [2 x i16], ptr %alloca, i32 0, i32 1
store <1 x i16> %b, ptr %ptr1

%result = load <1 x i32>, ptr %alloca
ret <1 x i32> %result
}

declare void @llvm.memset.p0.i64(ptr nocapture writeonly, i8, i64, i1 immarg)
4 changes: 4 additions & 0 deletions mlir/test/fusion/pr-e2e/lit.local.cfg
Original file line number Diff line number Diff line change
@@ -1,2 +1,6 @@
if not config.enable_rock_driver_pr_e2e_test or config.no_AMD_GPU:
config.unsupported = True

if not config.arch.startswith("gfx12"):
config.excludes = ['mixr-tadd-tadd-quant-dot.mlir']

15 changes: 15 additions & 0 deletions mlir/test/fusion/pr-e2e/mixr-tadd-tadd-quant-dot.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
// RUN: sed s/##TOKEN_ARCH##/%arch/g %s | rocmlir-driver -kernel-pipeline migraphx,highlevel | rocmlir-gen -ph -print-results -rand none - | rocmlir-driver -arch %arch -c | xmir-runner --shared-libs=%linalg_test_lib_dir/libmlir_rocm_runtime%shlibext,%conv_validation_wrapper_library_dir/libconv-validation-wrappers%shlibext,%linalg_test_lib_dir/libmlir_runner_utils%shlibext,%linalg_test_lib_dir/libmlir_float16_utils%shlibext,%linalg_test_lib_dir/libmlir_c_runner_utils%shlibext --entry-point-result=void | FileCheck %s

module {
// CHECK: [28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28]
func.func @mlir_add(%arg0: !migraphx.shaped<3x2x7x2xf8E4M3FN, 28x14x2x1>,
%arg1: !migraphx.shaped<3x2x5x7xf8E4M3FN, 70x35x7x1>) -> !migraphx.shaped<3x2x2x5xf32, 20x10x5x1> attributes {arch = "##TOKEN_ARCH##", kernel = "mixr"} {
%0 = migraphx.transpose %arg0 {permutation = [0, 1, 3, 2]} : <3x2x7x2xf8E4M3FN, 28x14x2x1> -> <3x2x2x7xf8E4M3FN, 28x14x1x2>
%1 = migraphx.add %0, %0 : <3x2x2x7xf8E4M3FN, 28x14x1x2>, <3x2x2x7xf8E4M3FN, 28x14x1x2> -> <3x2x2x7xf8E4M3FN, 28x14x1x2>
%2 = migraphx.transpose %arg1 {permutation = [0, 1, 3, 2]} : <3x2x5x7xf8E4M3FN, 70x35x7x1> -> <3x2x7x5xf8E4M3FN, 70x35x1x7>
%3 = migraphx.add %2, %2 : <3x2x7x5xf8E4M3FN, 70x35x1x7>, <3x2x7x5xf8E4M3FN, 70x35x1x7> -> <3x2x7x5xf8E4M3FN, 70x35x1x7>
%4 = migraphx.quant_dot %1, %3 {perf_config="v3:64,32,4,32,16,16,1,1,2,1,1"} : <3x2x2x7xf8E4M3FN, 28x14x1x2>, <3x2x7x5xf8E4M3FN, 70x35x1x7> -> <3x2x2x5xf32, 20x10x5x1>
return %4 : !migraphx.shaped<3x2x2x5xf32, 20x10x5x1>
}
}