Skip to content

Commit d3a00e0

Browse files
authored
Merge pull request #2011 from ROCm/justinr-const-fold-patch
[EXTERNAL] Fix v_mov_b16_t16 index in folding pass
2 parents d531704 + c11b4e0 commit d3a00e0

File tree

5 files changed

+95
-1
lines changed

5 files changed

+95
-1
lines changed

external/llvm-project/llvm/lib/Target/AMDGPU/SIFoldOperands.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -931,7 +931,9 @@ static MachineOperand *lookUpCopyChain(const SIInstrInfo &TII,
931931
for (MachineInstr *SubDef = MRI.getVRegDef(SrcReg);
932932
SubDef && TII.isFoldableCopy(*SubDef);
933933
SubDef = MRI.getVRegDef(Sub->getReg())) {
934-
MachineOperand &SrcOp = SubDef->getOperand(1);
934+
unsigned SrcIdx = TII.getFoldableCopySrcIdx(*SubDef);
935+
MachineOperand &SrcOp = SubDef->getOperand(SrcIdx);
936+
935937
if (SrcOp.isImm())
936938
return &SrcOp;
937939
if (!SrcOp.isReg() || SrcOp.getReg().isPhysical())

external/llvm-project/llvm/lib/Target/AMDGPU/SIInstrInfo.cpp

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3479,6 +3479,32 @@ bool SIInstrInfo::isFoldableCopy(const MachineInstr &MI) {
34793479
}
34803480
}
34813481

3482+
unsigned SIInstrInfo::getFoldableCopySrcIdx(const MachineInstr &MI) {
3483+
switch (MI.getOpcode()) {
3484+
case AMDGPU::V_MOV_B16_t16_e32:
3485+
case AMDGPU::V_MOV_B16_t16_e64:
3486+
return 2;
3487+
case AMDGPU::V_MOV_B32_e32:
3488+
case AMDGPU::V_MOV_B32_e64:
3489+
case AMDGPU::V_MOV_B64_PSEUDO:
3490+
case AMDGPU::V_MOV_B64_e32:
3491+
case AMDGPU::V_MOV_B64_e64:
3492+
case AMDGPU::S_MOV_B32:
3493+
case AMDGPU::S_MOV_B64:
3494+
case AMDGPU::S_MOV_B64_IMM_PSEUDO:
3495+
case AMDGPU::COPY:
3496+
case AMDGPU::WWM_COPY:
3497+
case AMDGPU::V_ACCVGPR_WRITE_B32_e64:
3498+
case AMDGPU::V_ACCVGPR_READ_B32_e64:
3499+
case AMDGPU::V_ACCVGPR_MOV_B32:
3500+
case AMDGPU::AV_MOV_B32_IMM_PSEUDO:
3501+
case AMDGPU::AV_MOV_B64_IMM_PSEUDO:
3502+
return 1;
3503+
default:
3504+
llvm_unreachable("MI is not a foldable copy");
3505+
}
3506+
}
3507+
34823508
static constexpr AMDGPU::OpName ModifierOpNames[] = {
34833509
AMDGPU::OpName::src0_modifiers, AMDGPU::OpName::src1_modifiers,
34843510
AMDGPU::OpName::src2_modifiers, AMDGPU::OpName::clamp,

external/llvm-project/llvm/lib/Target/AMDGPU/SIInstrInfo.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -433,6 +433,7 @@ class SIInstrInfo final : public AMDGPUGenInstrInfo {
433433
const MachineInstr &MIb) const override;
434434

435435
static bool isFoldableCopy(const MachineInstr &MI);
436+
/// Returns the index of the operand holding the source value of the
/// foldable copy \p MI (2 for the V_MOV_B16_t16 opcodes, 1 for the other
/// handled copy/move opcodes). \p MI must be a foldable copy; any other
/// opcode hits llvm_unreachable in the definition.
static unsigned getFoldableCopySrcIdx(const MachineInstr &MI);
436437

437438
void removeModOperands(MachineInstr &MI) const;
438439

external/llvm-project/llvm/test/CodeGen/AMDGPU/true16-fold.mir

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ body: |
5959
%4:vgpr_16 = COPY %3:sgpr_lo16
6060
%5:vgpr_32 = V_ALIGNBIT_B32_t16_e64 0, %0:sreg_32, 0, killed %1:sreg_32, 0, killed %4:vgpr_16, 0, 0, implicit $exec
6161
S_ENDPGM 0, implicit %5
62+
...
6263

6364
---
6465
name: fold_16bit_madmix_clamp
@@ -197,3 +198,28 @@ body: |
197198
$vgpr0 = COPY %4
198199
S_ENDPGM 0, implicit $vgpr0
199200
...
201+
202+
---
203+
name: fold_imm16_across_reg_sequence
204+
tracksRegLiveness: true
205+
registers:
206+
body: |
207+
bb.0:
208+
liveins: $vgpr0, $vgpr1, $vgpr2
209+
; CHECK-LABEL: name: fold_imm16_across_reg_sequence
210+
; CHECK: liveins: $vgpr0, $vgpr1, $vgpr2
211+
; CHECK-NEXT: {{ $}}
212+
; CHECK-NEXT: [[V_MOV_B16_t16_e64_:%[0-9]+]]:vgpr_16 = V_MOV_B16_t16_e64 0, -1, 0, implicit $exec
213+
; CHECK-NEXT: [[V_MOV_B16_t16_e64_1:%[0-9]+]]:vgpr_16 = V_MOV_B16_t16_e64 0, -1, 0, implicit $exec
214+
; CHECK-NEXT: [[REG_SEQUENCE:%[0-9]+]]:vgpr_32 = REG_SEQUENCE [[V_MOV_B16_t16_e64_]], %subreg.lo16, [[V_MOV_B16_t16_e64_1]], %subreg.hi16
215+
; CHECK-NEXT: [[V_MAX_F32_e64_:%[0-9]+]]:vgpr_32 = nofpexcept V_MAX_F32_e64 0, -1, 0, -1, 0, 0, implicit $mode, implicit $exec
216+
; CHECK-NEXT: $vgpr0 = COPY [[V_MAX_F32_e64_]]
217+
; CHECK-NEXT: S_ENDPGM 0, implicit $vgpr0
218+
%0:vgpr_16 = V_MOV_B16_t16_e64 0, -1, 0, implicit $exec
219+
%1:vgpr_16 = V_MOV_B16_t16_e64 0, -1, 0, implicit $exec
220+
%2:vgpr_32 = REG_SEQUENCE %0, %subreg.lo16, %1, %subreg.hi16
221+
%3:vgpr_32 = nofpexcept V_MAX_F32_e64 0, %2, 0, %2, 0, 0, implicit $mode, implicit $exec
222+
$vgpr0 = COPY %3
223+
S_ENDPGM 0, implicit $vgpr0
224+
...
225+
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
// This design comes from the MIGraphX CI, where it was previously failing.
2+
// Tracked at: https://ontrack-internal.amd.com/browse/SWDEV-558297
3+
4+
// RUN: rocmlir-gen -fut mlir_attention --arch %arch --clone-harness %s | rocmlir-driver -kernel-pipeline=migraphx | rocmlir-driver -host-pipeline=migraphx,highlevel | rocmlir-gen -ph -rand 1 -rand_type float -fut mlir_attention_wrapper --verifier clone - | rocmlir-driver -host-pipeline mhal -kernel-pipeline full | xmir-runner --shared-libs=%linalg_test_lib_dir/libmlir_rocm_runtime%shlibext,%conv_validation_wrapper_library_dir/libconv-validation-wrappers%shlibext,%linalg_test_lib_dir/libmlir_runner_utils%shlibext,%linalg_test_lib_dir/libmlir_float16_utils%shlibext,%linalg_test_lib_dir/libmlir_c_runner_utils%shlibext,%linalg_test_lib_dir/libmlir_async_runtime%shlibext --entry-point-result=void | FileCheck %s
5+
6+
module {
7+
// CHECK: [1 1 1]
8+
func.func @mlir_attention(
9+
%arg0: !migraphx.shaped<1x1x12xf16, 12x12x1>,
10+
%arg1: !migraphx.shaped<1x2x4x2xf16, 16x8x2x1>,
11+
%arg2: !migraphx.shaped<1x2x4x2xf16, 16x8x2x1>,
12+
%arg3: !migraphx.shaped<1x1xsi32, 1x1>
13+
) -> !migraphx.shaped<1x1x4xf16, 4x4x1> attributes {arch = "", kernel = "mixr", num_cu = 0 : i64} {
14+
%0 = migraphx.literal(dense<[0, 1, 2, 3]> : tensor<4xsi32>) : <4xsi32, 1>
15+
%1 = migraphx.literal(dense<0xFC00> : tensor<1xf16>) : <1xf16, 1>
16+
%2 = migraphx.literal(dense<1.000000e+00> : tensor<1xf16>) : <1xf16, 1>
17+
%3 = migraphx.reshape %arg0 {dims = [1, 1, 6, 2]} : <1x1x12xf16, 12x12x1> -> <1x1x6x2xf16, 12x12x2x1>
18+
%4 = migraphx.transpose %3 {permutation = [0, 2, 1, 3]} : <1x1x6x2xf16, 12x12x2x1> -> <1x6x1x2xf16, 12x2x12x1>
19+
%5 = migraphx.multibroadcast %arg3 {out_dyn_dims = [], out_lens = [1, 2]} : <1x1xsi32, 1x1> -> <1x2xsi32, 1x0>
20+
%6 = migraphx.slice %4 {axes = [1], ends = [2], starts = [0]} : <1x6x1x2xf16, 12x2x12x1> -> <1x2x1x2xf16, 12x2x12x1>
21+
%7 = migraphx.transpose %arg1 {permutation = [0, 1, 3, 2]} : <1x2x4x2xf16, 16x8x2x1> -> <1x2x2x4xf16, 16x8x1x2>
22+
%8 = migraphx.dot %6, %7 : <1x2x1x2xf16, 12x2x12x1>, <1x2x2x4xf16, 16x8x1x2> -> <1x2x1x4xf16, 8x4x4x1>
23+
%9 = migraphx.multibroadcast %0 {out_dyn_dims = [], out_lens = [1, 2, 1, 4]} : <4xsi32, 1> -> <1x2x1x4xsi32, 0x0x0x1>
24+
%10 = migraphx.multibroadcast %1 {out_dyn_dims = [], out_lens = [1, 2, 1, 4]} : <1xf16, 1> -> <1x2x1x4xf16, 0x0x0x0>
25+
%11 = migraphx.multibroadcast %2 {out_dyn_dims = [], out_lens = [1, 2, 1, 4]} : <1xf16, 1> -> <1x2x1x4xf16, 0x0x0x0>
26+
%12 = migraphx.mul %8, %11 : <1x2x1x4xf16, 8x4x4x1>, <1x2x1x4xf16, 0x0x0x0> -> <1x2x1x4xf16, 8x4x4x1>
27+
%13 = migraphx.reshape %5 {dims = [1, 2, 1, 1]} : <1x2xsi32, 1x0> -> <1x2x1x1xsi32, 2x1x1x1>
28+
%14 = migraphx.multibroadcast %13 {out_dyn_dims = [], out_lens = [1, 2, 1, 4]} : <1x2x1x1xsi32, 2x1x1x1> -> <1x2x1x4xsi32, 2x1x1x0>
29+
%15 = migraphx.greater %9, %14 : <1x2x1x4xsi32, 0x0x0x1>, <1x2x1x4xsi32, 2x1x1x0> -> <1x2x1x4xsi32, 8x4x4x1>
30+
%16 = migraphx.convert %15 {target_type = 0 : i64} : <1x2x1x4xsi32, 8x4x4x1> to <1x2x1x4xsi8, 8x4x4x1>
31+
%17 = migraphx.where %16, %10, %12 : <1x2x1x4xsi8, 8x4x4x1>, <1x2x1x4xf16, 0x0x0x0>, <1x2x1x4xf16, 8x4x4x1> -> <1x2x1x4xf16, 8x4x4x1>
32+
%18 = migraphx.softmax %17 {axis = 3 : i64} : <1x2x1x4xf16, 8x4x4x1> -> <1x2x1x4xf16, 8x4x4x1>
33+
%19 = migraphx.dot %18, %arg2 : <1x2x1x4xf16, 8x4x4x1>, <1x2x4x2xf16, 16x8x2x1> -> <1x2x1x2xf16, 4x2x2x1>
34+
%20 = migraphx.transpose %19 {permutation = [0, 2, 1, 3]} : <1x2x1x2xf16, 4x2x2x1> -> <1x1x2x2xf16, 4x2x2x1>
35+
%21 = migraphx.reshape %20 {dims = [1, 1, 4]} : <1x1x2x2xf16, 4x2x2x1> -> <1x1x4xf16, 4x4x1>
36+
return %21 : !migraphx.shaped<1x1x4xf16, 4x4x1>
37+
}
38+
}
39+

0 commit comments

Comments
 (0)