Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 38 additions & 1 deletion llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,9 @@
#ifndef LLVM_LIB_TARGET_NVPTX_NVPTXTARGETTRANSFORMINFO_H
#define LLVM_LIB_TARGET_NVPTX_NVPTXTARGETTRANSFORMINFO_H

#include "NVPTXTargetMachine.h"
#include "MCTargetDesc/NVPTXBaseInfo.h"
#include "NVPTXTargetMachine.h"
#include "NVPTXUtilities.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/CodeGen/BasicTTIImpl.h"
#include "llvm/CodeGen/TargetLowering.h"
Expand Down Expand Up @@ -104,6 +105,42 @@ class NVPTXTTIImpl : public BasicTTIImplBase<NVPTXTTIImpl> {
TTI::OperandValueInfo Op2Info = {TTI::OK_AnyValue, TTI::OP_None},
ArrayRef<const Value *> Args = {}, const Instruction *CxtI = nullptr);

/// Estimate the cost of building (Insert) and/or taking apart (Extract) the
/// demanded elements of the vector type \p InTy.
///
/// Insert-side special cases handled here before deferring to the base
/// implementation:
///  - every demanded element of \p VL is a Constant: no insert cost;
///  - types accepted by Isv2x16VT: one instruction (a single mov);
///  - v4i8: 3 PRMT plus one zext per demanded element.
/// Once one of these applies, Insert is cleared so the base implementation
/// only accounts for what is left (e.g. the Extract side).
///
/// \param InTy         vector type being scalarized.
/// \param DemandedElts mask of the elements that are actually needed.
/// \param Insert       whether the cost of inserting elements is requested.
/// \param Extract      whether the cost of extracting elements is requested.
/// \param CostKind     forwarded unchanged to the base implementation.
/// \param VL           optional scalar operands, indexed by element.
/// \returns an invalid cost when \p InTy is not a fixed-width vector.
InstructionCost getScalarizationOverhead(VectorType *InTy,
                                         const APInt &DemandedElts,
                                         bool Insert, bool Extract,
                                         TTI::TargetCostKind CostKind,
                                         ArrayRef<Value *> VL = {}) {
  if (!InTy->getElementCount().isFixed())
    return InstructionCost::getInvalid();

  auto VT = getTLI()->getValueType(DL, InTy);
  auto NumElements = InTy->getElementCount().getFixedValue();
  InstructionCost Cost = 0;
  if (Insert && !VL.empty()) {
    // Elements that are not demanded do not matter; all demanded ones must be
    // constants for the build to be free.
    bool AllConstant = all_of(seq(NumElements), [&](unsigned Idx) {
      return !DemandedElts[Idx] || isa<Constant>(VL[Idx]);
    });
    if (AllConstant) {
      Cost += TTI::TCC_Free;
      Insert = false;
    }
  }
  if (Insert && Isv2x16VT(VT)) {
    // Can be built in a single mov
    Cost += 1;
    Insert = false;
  }
  if (Insert && VT == MVT::v4i8) {
    // Accumulate into the function-level Cost. Declaring a new local here
    // would shadow it, and the v4i8 cost would never reach the return value.
    Cost += 3; // 3 x PRMT
    // Each demanded operand is charged a zext to i32. From the review
    // discussion on this line: only widening of the register type is
    // required, not the zeroing itself, but PTX has no way to express that,
    // and the current lowering sometimes yields SASS that masks the upper
    // bits (e.g. a LOP3; see https://godbolt.org/z/EPjjzv4cz). ptxas appears
    // to skip the mask when it knows the upper bits are already zero, as
    // after a load.
    for (auto Idx : seq(NumElements))
      if (DemandedElts[Idx])
        Cost += 1; // zext operand to i32
    Insert = false;
  }
  return Cost + BaseT::getScalarizationOverhead(InTy, DemandedElts, Insert,
                                                Extract, CostKind, VL);
}

void getUnrollingPreferences(Loop *L, ScalarEvolution &SE,
TTI::UnrollingPreferences &UP,
OptimizationRemarkEmitter *ORE);
Expand Down
151 changes: 105 additions & 46 deletions llvm/test/Transforms/SLPVectorizer/NVPTX/v2f16.ll
Original file line number Diff line number Diff line change
@@ -1,59 +1,118 @@
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
; RUN: opt < %s -passes=slp-vectorizer -S -mtriple=nvptx64-nvidia-cuda -mcpu=sm_70 | FileCheck %s
; RUN: opt < %s -passes=slp-vectorizer -S -mtriple=nvptx64-nvidia-cuda -mcpu=sm_40 | FileCheck %s -check-prefix=NOVECTOR
; RUN: opt < %s -passes=slp-vectorizer -S -mtriple=nvptx64-nvidia-cuda -mcpu=sm_90 | FileCheck %s -check-prefix=VECTOR
; RUN: opt < %s -passes=slp-vectorizer -S -mtriple=nvptx64-nvidia-cuda -mcpu=sm_80 | FileCheck %s -check-prefix=VECTOR
; RUN: opt < %s -passes=slp-vectorizer -S -mtriple=nvptx64-nvidia-cuda -mcpu=sm_70 | FileCheck %s -check-prefix=VECTOR
; RUN: opt < %s -passes=slp-vectorizer -S -mtriple=nvptx64-nvidia-cuda -mcpu=sm_50 | FileCheck %s -check-prefix=NOVECTOR

; @fusion: for two adjacent half elements (index %6 and %6 | 1), load from
; %arg1, multiply by 0xH5380, add 0xH57F0, and store to %arg.
; The VECTOR check lines expect a single <2 x half> load/fmul/fadd/store; the
; NOVECTOR check lines expect the scalar operations to remain.
; NOTE(review): this listing appears to mix two versions of the test (both
; %tmp*-named and numbered values, plus an extra check prefix) - confirm
; against the checked-in file.
define void @fusion(ptr noalias nocapture align 256 dereferenceable(19267584) %arg, ptr noalias nocapture readonly align 256 dereferenceable(19267584) %arg1, i32 %arg2, i32 %arg3) local_unnamed_addr #0 {
; CHECK-LABEL: @fusion(
; CHECK-NEXT: [[TMP:%.*]] = shl nuw nsw i32 [[ARG2:%.*]], 6
; CHECK-NEXT: [[TMP4:%.*]] = or i32 [[TMP]], [[ARG3:%.*]]
; CHECK-NEXT: [[TMP5:%.*]] = shl nuw nsw i32 [[TMP4]], 2
; CHECK-NEXT: [[TMP6:%.*]] = zext i32 [[TMP5]] to i64
; CHECK-NEXT: [[TMP11:%.*]] = getelementptr inbounds half, ptr [[ARG1:%.*]], i64 [[TMP6]]
; CHECK-NEXT: [[TMP16:%.*]] = getelementptr inbounds half, ptr [[ARG:%.*]], i64 [[TMP6]]
; CHECK-NEXT: [[TMP1:%.*]] = load <2 x half>, ptr [[TMP11]], align 8
; CHECK-NEXT: [[TMP2:%.*]] = fmul fast <2 x half> [[TMP1]], splat (half 0xH5380)
; CHECK-NEXT: [[TMP3:%.*]] = fadd fast <2 x half> [[TMP2]], splat (half 0xH57F0)
; CHECK-NEXT: store <2 x half> [[TMP3]], ptr [[TMP16]], align 8
; CHECK-NEXT: ret void
; VECTOR-LABEL: @fusion(
; VECTOR-NEXT: [[TMP1:%.*]] = shl nuw nsw i32 [[ARG2:%.*]], 6
; VECTOR-NEXT: [[TMP2:%.*]] = or i32 [[TMP1]], [[ARG3:%.*]]
; VECTOR-NEXT: [[TMP3:%.*]] = shl nuw nsw i32 [[TMP2]], 2
; VECTOR-NEXT: [[TMP4:%.*]] = zext i32 [[TMP3]] to i64
; VECTOR-NEXT: [[TMP5:%.*]] = getelementptr inbounds half, ptr [[ARG1:%.*]], i64 [[TMP4]]
; VECTOR-NEXT: [[TMP6:%.*]] = getelementptr inbounds half, ptr [[ARG:%.*]], i64 [[TMP4]]
; VECTOR-NEXT: [[TMP7:%.*]] = load <2 x half>, ptr [[TMP5]], align 8
; VECTOR-NEXT: [[TMP8:%.*]] = fmul fast <2 x half> [[TMP7]], splat (half 0xH5380)
; VECTOR-NEXT: [[TMP9:%.*]] = fadd fast <2 x half> [[TMP8]], splat (half 0xH57F0)
; VECTOR-NEXT: store <2 x half> [[TMP9]], ptr [[TMP6]], align 8
; VECTOR-NEXT: ret void
;
; NOVECTOR-LABEL: @fusion(
; NOVECTOR-NEXT: [[TMP:%.*]] = shl nuw nsw i32 [[ARG2:%.*]], 6
; NOVECTOR-NEXT: [[TMP4:%.*]] = or i32 [[TMP]], [[ARG3:%.*]]
; NOVECTOR-NEXT: [[TMP5:%.*]] = shl nuw nsw i32 [[TMP4]], 2
; NOVECTOR-NEXT: [[TMP6:%.*]] = zext i32 [[TMP5]] to i64
; NOVECTOR-NEXT: [[TMP7:%.*]] = or disjoint i64 [[TMP6]], 1
; NOVECTOR-NEXT: [[TMP11:%.*]] = getelementptr inbounds half, ptr [[ARG1:%.*]], i64 [[TMP6]]
; NOVECTOR-NEXT: [[TMP12:%.*]] = load half, ptr [[TMP11]], align 8
; NOVECTOR-NEXT: [[TMP1:%.*]] = shl nuw nsw i32 [[ARG2:%.*]], 6
; NOVECTOR-NEXT: [[TMP2:%.*]] = or i32 [[TMP1]], [[ARG3:%.*]]
; NOVECTOR-NEXT: [[TMP3:%.*]] = shl nuw nsw i32 [[TMP2]], 2
; NOVECTOR-NEXT: [[TMP4:%.*]] = zext i32 [[TMP3]] to i64
; NOVECTOR-NEXT: [[TMP10:%.*]] = or disjoint i64 [[TMP4]], 1
; NOVECTOR-NEXT: [[TMP5:%.*]] = getelementptr inbounds half, ptr [[ARG1:%.*]], i64 [[TMP4]]
; NOVECTOR-NEXT: [[TMP7:%.*]] = load half, ptr [[TMP5]], align 8
; NOVECTOR-NEXT: [[TMP8:%.*]] = fmul fast half [[TMP7]], 0xH5380
; NOVECTOR-NEXT: [[TMP9:%.*]] = fadd fast half [[TMP8]], 0xH57F0
; NOVECTOR-NEXT: [[TMP6:%.*]] = getelementptr inbounds half, ptr [[ARG:%.*]], i64 [[TMP4]]
; NOVECTOR-NEXT: store half [[TMP9]], ptr [[TMP6]], align 8
; NOVECTOR-NEXT: [[TMP11:%.*]] = getelementptr inbounds half, ptr [[ARG1]], i64 [[TMP10]]
; NOVECTOR-NEXT: [[TMP12:%.*]] = load half, ptr [[TMP11]], align 2
; NOVECTOR-NEXT: [[TMP13:%.*]] = fmul fast half [[TMP12]], 0xH5380
; NOVECTOR-NEXT: [[TMP14:%.*]] = fadd fast half [[TMP13]], 0xH57F0
; NOVECTOR-NEXT: [[TMP16:%.*]] = getelementptr inbounds half, ptr [[ARG:%.*]], i64 [[TMP6]]
; NOVECTOR-NEXT: store half [[TMP14]], ptr [[TMP16]], align 8
; NOVECTOR-NEXT: [[TMP17:%.*]] = getelementptr inbounds half, ptr [[ARG1]], i64 [[TMP7]]
; NOVECTOR-NEXT: [[TMP18:%.*]] = load half, ptr [[TMP17]], align 2
; NOVECTOR-NEXT: [[TMP19:%.*]] = fmul fast half [[TMP18]], 0xH5380
; NOVECTOR-NEXT: [[TMP20:%.*]] = fadd fast half [[TMP19]], 0xH57F0
; NOVECTOR-NEXT: [[TMP21:%.*]] = getelementptr inbounds half, ptr [[ARG]], i64 [[TMP7]]
; NOVECTOR-NEXT: store half [[TMP20]], ptr [[TMP21]], align 2
; NOVECTOR-NEXT: [[TMP15:%.*]] = getelementptr inbounds half, ptr [[ARG]], i64 [[TMP10]]
; NOVECTOR-NEXT: store half [[TMP14]], ptr [[TMP15]], align 2
; NOVECTOR-NEXT: ret void
;
%tmp = shl nuw nsw i32 %arg2, 6
%tmp4 = or i32 %tmp, %arg3
%tmp5 = shl nuw nsw i32 %tmp4, 2
%tmp6 = zext i32 %tmp5 to i64
%tmp7 = or disjoint i64 %tmp6, 1
%tmp11 = getelementptr inbounds half, ptr %arg1, i64 %tmp6
%tmp12 = load half, ptr %tmp11, align 8
%tmp13 = fmul fast half %tmp12, 0xH5380
%tmp14 = fadd fast half %tmp13, 0xH57F0
%tmp16 = getelementptr inbounds half, ptr %arg, i64 %tmp6
store half %tmp14, ptr %tmp16, align 8
%tmp17 = getelementptr inbounds half, ptr %arg1, i64 %tmp7
%tmp18 = load half, ptr %tmp17, align 2
%tmp19 = fmul fast half %tmp18, 0xH5380
%tmp20 = fadd fast half %tmp19, 0xH57F0
%tmp21 = getelementptr inbounds half, ptr %arg, i64 %tmp7
store half %tmp20, ptr %tmp21, align 2
%1 = shl nuw nsw i32 %arg2, 6
%4 = or i32 %1, %arg3
%5 = shl nuw nsw i32 %4, 2
%6 = zext i32 %5 to i64
%7 = or disjoint i64 %6, 1
%11 = getelementptr inbounds half, ptr %arg1, i64 %6
%12 = load half, ptr %11, align 8
%13 = fmul fast half %12, 0xH5380
%14 = fadd fast half %13, 0xH57F0
%16 = getelementptr inbounds half, ptr %arg, i64 %6
store half %14, ptr %16, align 8
%17 = getelementptr inbounds half, ptr %arg1, i64 %7
%18 = load half, ptr %17, align 2
%19 = fmul fast half %18, 0xH5380
%20 = fadd fast half %19, 0xH57F0
%21 = getelementptr inbounds half, ptr %arg, i64 %7
store half %20, ptr %21, align 2
ret void
}

; @add_f16: adds the two { half, half } arguments element-wise, packs the two
; sums into a <2 x half> with insertelement, and stores it to %0 at the half
; index (tid.x << 1) & 62.
; The VECTOR check lines expect the operands to be packed first and a single
; fadd <2 x half>; the NOVECTOR check lines expect two scalar fadd instructions
; followed by the insertelement pair.
define ptx_kernel void @add_f16(ptr addrspace(1) %0, { half, half } %1, { half, half } %2) {
; VECTOR-LABEL: @add_f16(
; VECTOR-NEXT: [[TMP4:%.*]] = extractvalue { half, half } [[TMP1:%.*]], 0
; VECTOR-NEXT: [[TMP5:%.*]] = extractvalue { half, half } [[TMP1]], 1
; VECTOR-NEXT: [[TMP6:%.*]] = extractvalue { half, half } [[TMP2:%.*]], 0
; VECTOR-NEXT: [[TMP7:%.*]] = extractvalue { half, half } [[TMP2]], 1
; VECTOR-NEXT: [[TMP8:%.*]] = insertelement <2 x half> poison, half [[TMP4]], i32 0
; VECTOR-NEXT: [[TMP9:%.*]] = insertelement <2 x half> [[TMP8]], half [[TMP5]], i32 1
; VECTOR-NEXT: [[TMP10:%.*]] = insertelement <2 x half> poison, half [[TMP6]], i32 0
; VECTOR-NEXT: [[TMP11:%.*]] = insertelement <2 x half> [[TMP10]], half [[TMP7]], i32 1
; VECTOR-NEXT: [[TMP12:%.*]] = fadd <2 x half> [[TMP9]], [[TMP11]]
; VECTOR-NEXT: [[TMP13:%.*]] = call i32 @llvm.nvvm.read.ptx.sreg.tid.x()
; VECTOR-NEXT: [[TMP14:%.*]] = shl i32 [[TMP13]], 1
; VECTOR-NEXT: [[TMP15:%.*]] = and i32 [[TMP14]], 62
; VECTOR-NEXT: [[TMP16:%.*]] = zext nneg i32 [[TMP15]] to i64
; VECTOR-NEXT: [[TMP17:%.*]] = getelementptr half, ptr addrspace(1) [[TMP0:%.*]], i64 [[TMP16]]
; VECTOR-NEXT: store <2 x half> [[TMP12]], ptr addrspace(1) [[TMP17]], align 4
; VECTOR-NEXT: ret void
;
; NOVECTOR-LABEL: @add_f16(
; NOVECTOR-NEXT: [[TMP4:%.*]] = extractvalue { half, half } [[TMP1:%.*]], 0
; NOVECTOR-NEXT: [[TMP5:%.*]] = extractvalue { half, half } [[TMP1]], 1
; NOVECTOR-NEXT: [[TMP6:%.*]] = extractvalue { half, half } [[TMP2:%.*]], 0
; NOVECTOR-NEXT: [[TMP7:%.*]] = extractvalue { half, half } [[TMP2]], 1
; NOVECTOR-NEXT: [[TMP8:%.*]] = fadd half [[TMP4]], [[TMP6]]
; NOVECTOR-NEXT: [[TMP9:%.*]] = fadd half [[TMP5]], [[TMP7]]
; NOVECTOR-NEXT: [[TMP13:%.*]] = call i32 @llvm.nvvm.read.ptx.sreg.tid.x()
; NOVECTOR-NEXT: [[TMP14:%.*]] = shl i32 [[TMP13]], 1
; NOVECTOR-NEXT: [[TMP15:%.*]] = and i32 [[TMP14]], 62
; NOVECTOR-NEXT: [[TMP16:%.*]] = zext nneg i32 [[TMP15]] to i64
; NOVECTOR-NEXT: [[TMP17:%.*]] = getelementptr half, ptr addrspace(1) [[TMP0:%.*]], i64 [[TMP16]]
; NOVECTOR-NEXT: [[TMP19:%.*]] = insertelement <2 x half> poison, half [[TMP8]], i64 0
; NOVECTOR-NEXT: [[TMP12:%.*]] = insertelement <2 x half> [[TMP19]], half [[TMP9]], i64 1
; NOVECTOR-NEXT: store <2 x half> [[TMP12]], ptr addrspace(1) [[TMP17]], align 4
; NOVECTOR-NEXT: ret void
;
%5 = extractvalue { half, half } %1, 0
%6 = extractvalue { half, half } %1, 1
%7 = extractvalue { half, half } %2, 0
%8 = extractvalue { half, half } %2, 1
%9 = fadd half %5, %7
%10 = fadd half %6, %8
%11 = call i32 @llvm.nvvm.read.ptx.sreg.tid.x()
%12 = shl i32 %11, 1
%13 = and i32 %12, 62
%14 = zext nneg i32 %13 to i64
%15 = getelementptr half, ptr addrspace(1) %0, i64 %14
%18 = insertelement <2 x half> poison, half %9, i64 0
%19 = insertelement <2 x half> %18, half %10, i64 1
store <2 x half> %19, ptr addrspace(1) %15, align 4
ret void
}

; Function Attrs: nocallback nofree nosync nounwind speculatable willreturn memory(none)
declare noundef i32 @llvm.nvvm.read.ptx.sreg.tid.x() #1

attributes #0 = { nounwind }
attributes #1 = { nocallback nofree nosync nounwind speculatable willreturn memory(none) }