-
Notifications
You must be signed in to change notification settings - Fork 15.3k
[NVPTX] Customize getScalarizationOverhead #128077
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -16,8 +16,9 @@ | |
| #ifndef LLVM_LIB_TARGET_NVPTX_NVPTXTARGETTRANSFORMINFO_H | ||
| #define LLVM_LIB_TARGET_NVPTX_NVPTXTARGETTRANSFORMINFO_H | ||
|
|
||
| #include "NVPTXTargetMachine.h" | ||
| #include "MCTargetDesc/NVPTXBaseInfo.h" | ||
| #include "NVPTXTargetMachine.h" | ||
| #include "NVPTXUtilities.h" | ||
| #include "llvm/Analysis/TargetTransformInfo.h" | ||
| #include "llvm/CodeGen/BasicTTIImpl.h" | ||
| #include "llvm/CodeGen/TargetLowering.h" | ||
|
|
@@ -104,6 +105,42 @@ class NVPTXTTIImpl : public BasicTTIImplBase<NVPTXTTIImpl> { | |
| TTI::OperandValueInfo Op2Info = {TTI::OK_AnyValue, TTI::OP_None}, | ||
| ArrayRef<const Value *> Args = {}, const Instruction *CxtI = nullptr); | ||
|
|
||
| InstructionCost getScalarizationOverhead(VectorType *InTy, | ||
| const APInt &DemandedElts, | ||
| bool Insert, bool Extract, | ||
| TTI::TargetCostKind CostKind, | ||
| ArrayRef<Value *> VL = {}) { | ||
| if (!InTy->getElementCount().isFixed()) | ||
| return InstructionCost::getInvalid(); | ||
|
|
||
| auto VT = getTLI()->getValueType(DL, InTy); | ||
| auto NumElements = InTy->getElementCount().getFixedValue(); | ||
| InstructionCost Cost = 0; | ||
| if (Insert && !VL.empty()) { | ||
| bool AllConstant = all_of(seq(NumElements), [&](int Idx) { | ||
| return !DemandedElts[Idx] || isa<Constant>(VL[Idx]); | ||
| }); | ||
| if (AllConstant) { | ||
| Cost += TTI::TCC_Free; | ||
| Insert = false; | ||
| } | ||
| } | ||
| if (Insert && Isv2x16VT(VT)) { | ||
| // Can be built in a single mov | ||
| Cost += 1; | ||
| Insert = false; | ||
| } | ||
| if (Insert && VT == MVT::v4i8) { | ||
| InstructionCost Cost = 3; // 3 x PRMT | ||
| for (auto Idx : seq(NumElements)) | ||
| if (DemandedElts[Idx]) | ||
| Cost += 1; // zext operand to i32 | ||
|
||
| Insert = false; | ||
| } | ||
| return Cost + BaseT::getScalarizationOverhead(InTy, DemandedElts, Insert, | ||
| Extract, CostKind, VL); | ||
| } | ||
|
|
||
| void getUnrollingPreferences(Loop *L, ScalarEvolution &SE, | ||
| TTI::UnrollingPreferences &UP, | ||
| OptimizationRemarkEmitter *ORE); | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,59 +1,118 @@ | ||
| ; NOTE: Assertions have been autogenerated by utils/update_test_checks.py | ||
| ; RUN: opt < %s -passes=slp-vectorizer -S -mtriple=nvptx64-nvidia-cuda -mcpu=sm_70 | FileCheck %s | ||
| ; RUN: opt < %s -passes=slp-vectorizer -S -mtriple=nvptx64-nvidia-cuda -mcpu=sm_40 | FileCheck %s -check-prefix=NOVECTOR | ||
| ; RUN: opt < %s -passes=slp-vectorizer -S -mtriple=nvptx64-nvidia-cuda -mcpu=sm_90 | FileCheck %s -check-prefix=VECTOR | ||
| ; RUN: opt < %s -passes=slp-vectorizer -S -mtriple=nvptx64-nvidia-cuda -mcpu=sm_80 | FileCheck %s -check-prefix=VECTOR | ||
| ; RUN: opt < %s -passes=slp-vectorizer -S -mtriple=nvptx64-nvidia-cuda -mcpu=sm_70 | FileCheck %s -check-prefix=VECTOR | ||
| ; RUN: opt < %s -passes=slp-vectorizer -S -mtriple=nvptx64-nvidia-cuda -mcpu=sm_50 | FileCheck %s -check-prefix=NOVECTOR | ||
|
|
||
| define void @fusion(ptr noalias nocapture align 256 dereferenceable(19267584) %arg, ptr noalias nocapture readonly align 256 dereferenceable(19267584) %arg1, i32 %arg2, i32 %arg3) local_unnamed_addr #0 { | ||
| ; CHECK-LABEL: @fusion( | ||
| ; CHECK-NEXT: [[TMP:%.*]] = shl nuw nsw i32 [[ARG2:%.*]], 6 | ||
| ; CHECK-NEXT: [[TMP4:%.*]] = or i32 [[TMP]], [[ARG3:%.*]] | ||
| ; CHECK-NEXT: [[TMP5:%.*]] = shl nuw nsw i32 [[TMP4]], 2 | ||
| ; CHECK-NEXT: [[TMP6:%.*]] = zext i32 [[TMP5]] to i64 | ||
| ; CHECK-NEXT: [[TMP11:%.*]] = getelementptr inbounds half, ptr [[ARG1:%.*]], i64 [[TMP6]] | ||
| ; CHECK-NEXT: [[TMP16:%.*]] = getelementptr inbounds half, ptr [[ARG:%.*]], i64 [[TMP6]] | ||
| ; CHECK-NEXT: [[TMP1:%.*]] = load <2 x half>, ptr [[TMP11]], align 8 | ||
| ; CHECK-NEXT: [[TMP2:%.*]] = fmul fast <2 x half> [[TMP1]], splat (half 0xH5380) | ||
| ; CHECK-NEXT: [[TMP3:%.*]] = fadd fast <2 x half> [[TMP2]], splat (half 0xH57F0) | ||
| ; CHECK-NEXT: store <2 x half> [[TMP3]], ptr [[TMP16]], align 8 | ||
| ; CHECK-NEXT: ret void | ||
| ; VECTOR-LABEL: @fusion( | ||
| ; VECTOR-NEXT: [[TMP1:%.*]] = shl nuw nsw i32 [[ARG2:%.*]], 6 | ||
| ; VECTOR-NEXT: [[TMP2:%.*]] = or i32 [[TMP1]], [[ARG3:%.*]] | ||
| ; VECTOR-NEXT: [[TMP3:%.*]] = shl nuw nsw i32 [[TMP2]], 2 | ||
| ; VECTOR-NEXT: [[TMP4:%.*]] = zext i32 [[TMP3]] to i64 | ||
| ; VECTOR-NEXT: [[TMP5:%.*]] = getelementptr inbounds half, ptr [[ARG1:%.*]], i64 [[TMP4]] | ||
| ; VECTOR-NEXT: [[TMP6:%.*]] = getelementptr inbounds half, ptr [[ARG:%.*]], i64 [[TMP4]] | ||
| ; VECTOR-NEXT: [[TMP7:%.*]] = load <2 x half>, ptr [[TMP5]], align 8 | ||
| ; VECTOR-NEXT: [[TMP8:%.*]] = fmul fast <2 x half> [[TMP7]], splat (half 0xH5380) | ||
| ; VECTOR-NEXT: [[TMP9:%.*]] = fadd fast <2 x half> [[TMP8]], splat (half 0xH57F0) | ||
| ; VECTOR-NEXT: store <2 x half> [[TMP9]], ptr [[TMP6]], align 8 | ||
| ; VECTOR-NEXT: ret void | ||
| ; | ||
| ; NOVECTOR-LABEL: @fusion( | ||
| ; NOVECTOR-NEXT: [[TMP:%.*]] = shl nuw nsw i32 [[ARG2:%.*]], 6 | ||
| ; NOVECTOR-NEXT: [[TMP4:%.*]] = or i32 [[TMP]], [[ARG3:%.*]] | ||
| ; NOVECTOR-NEXT: [[TMP5:%.*]] = shl nuw nsw i32 [[TMP4]], 2 | ||
| ; NOVECTOR-NEXT: [[TMP6:%.*]] = zext i32 [[TMP5]] to i64 | ||
| ; NOVECTOR-NEXT: [[TMP7:%.*]] = or disjoint i64 [[TMP6]], 1 | ||
| ; NOVECTOR-NEXT: [[TMP11:%.*]] = getelementptr inbounds half, ptr [[ARG1:%.*]], i64 [[TMP6]] | ||
| ; NOVECTOR-NEXT: [[TMP12:%.*]] = load half, ptr [[TMP11]], align 8 | ||
| ; NOVECTOR-NEXT: [[TMP1:%.*]] = shl nuw nsw i32 [[ARG2:%.*]], 6 | ||
| ; NOVECTOR-NEXT: [[TMP2:%.*]] = or i32 [[TMP1]], [[ARG3:%.*]] | ||
| ; NOVECTOR-NEXT: [[TMP3:%.*]] = shl nuw nsw i32 [[TMP2]], 2 | ||
| ; NOVECTOR-NEXT: [[TMP4:%.*]] = zext i32 [[TMP3]] to i64 | ||
| ; NOVECTOR-NEXT: [[TMP10:%.*]] = or disjoint i64 [[TMP4]], 1 | ||
| ; NOVECTOR-NEXT: [[TMP5:%.*]] = getelementptr inbounds half, ptr [[ARG1:%.*]], i64 [[TMP4]] | ||
| ; NOVECTOR-NEXT: [[TMP7:%.*]] = load half, ptr [[TMP5]], align 8 | ||
| ; NOVECTOR-NEXT: [[TMP8:%.*]] = fmul fast half [[TMP7]], 0xH5380 | ||
| ; NOVECTOR-NEXT: [[TMP9:%.*]] = fadd fast half [[TMP8]], 0xH57F0 | ||
| ; NOVECTOR-NEXT: [[TMP6:%.*]] = getelementptr inbounds half, ptr [[ARG:%.*]], i64 [[TMP4]] | ||
| ; NOVECTOR-NEXT: store half [[TMP9]], ptr [[TMP6]], align 8 | ||
| ; NOVECTOR-NEXT: [[TMP11:%.*]] = getelementptr inbounds half, ptr [[ARG1]], i64 [[TMP10]] | ||
| ; NOVECTOR-NEXT: [[TMP12:%.*]] = load half, ptr [[TMP11]], align 2 | ||
| ; NOVECTOR-NEXT: [[TMP13:%.*]] = fmul fast half [[TMP12]], 0xH5380 | ||
| ; NOVECTOR-NEXT: [[TMP14:%.*]] = fadd fast half [[TMP13]], 0xH57F0 | ||
| ; NOVECTOR-NEXT: [[TMP16:%.*]] = getelementptr inbounds half, ptr [[ARG:%.*]], i64 [[TMP6]] | ||
| ; NOVECTOR-NEXT: store half [[TMP14]], ptr [[TMP16]], align 8 | ||
| ; NOVECTOR-NEXT: [[TMP17:%.*]] = getelementptr inbounds half, ptr [[ARG1]], i64 [[TMP7]] | ||
| ; NOVECTOR-NEXT: [[TMP18:%.*]] = load half, ptr [[TMP17]], align 2 | ||
| ; NOVECTOR-NEXT: [[TMP19:%.*]] = fmul fast half [[TMP18]], 0xH5380 | ||
| ; NOVECTOR-NEXT: [[TMP20:%.*]] = fadd fast half [[TMP19]], 0xH57F0 | ||
| ; NOVECTOR-NEXT: [[TMP21:%.*]] = getelementptr inbounds half, ptr [[ARG]], i64 [[TMP7]] | ||
| ; NOVECTOR-NEXT: store half [[TMP20]], ptr [[TMP21]], align 2 | ||
| ; NOVECTOR-NEXT: [[TMP15:%.*]] = getelementptr inbounds half, ptr [[ARG]], i64 [[TMP10]] | ||
| ; NOVECTOR-NEXT: store half [[TMP14]], ptr [[TMP15]], align 2 | ||
| ; NOVECTOR-NEXT: ret void | ||
| ; | ||
| %tmp = shl nuw nsw i32 %arg2, 6 | ||
| %tmp4 = or i32 %tmp, %arg3 | ||
| %tmp5 = shl nuw nsw i32 %tmp4, 2 | ||
| %tmp6 = zext i32 %tmp5 to i64 | ||
| %tmp7 = or disjoint i64 %tmp6, 1 | ||
| %tmp11 = getelementptr inbounds half, ptr %arg1, i64 %tmp6 | ||
| %tmp12 = load half, ptr %tmp11, align 8 | ||
| %tmp13 = fmul fast half %tmp12, 0xH5380 | ||
| %tmp14 = fadd fast half %tmp13, 0xH57F0 | ||
| %tmp16 = getelementptr inbounds half, ptr %arg, i64 %tmp6 | ||
| store half %tmp14, ptr %tmp16, align 8 | ||
| %tmp17 = getelementptr inbounds half, ptr %arg1, i64 %tmp7 | ||
| %tmp18 = load half, ptr %tmp17, align 2 | ||
| %tmp19 = fmul fast half %tmp18, 0xH5380 | ||
| %tmp20 = fadd fast half %tmp19, 0xH57F0 | ||
| %tmp21 = getelementptr inbounds half, ptr %arg, i64 %tmp7 | ||
| store half %tmp20, ptr %tmp21, align 2 | ||
| %1 = shl nuw nsw i32 %arg2, 6 | ||
| %4 = or i32 %1, %arg3 | ||
| %5 = shl nuw nsw i32 %4, 2 | ||
| %6 = zext i32 %5 to i64 | ||
| %7 = or disjoint i64 %6, 1 | ||
| %11 = getelementptr inbounds half, ptr %arg1, i64 %6 | ||
| %12 = load half, ptr %11, align 8 | ||
| %13 = fmul fast half %12, 0xH5380 | ||
| %14 = fadd fast half %13, 0xH57F0 | ||
| %16 = getelementptr inbounds half, ptr %arg, i64 %6 | ||
| store half %14, ptr %16, align 8 | ||
| %17 = getelementptr inbounds half, ptr %arg1, i64 %7 | ||
| %18 = load half, ptr %17, align 2 | ||
| %19 = fmul fast half %18, 0xH5380 | ||
| %20 = fadd fast half %19, 0xH57F0 | ||
| %21 = getelementptr inbounds half, ptr %arg, i64 %7 | ||
| store half %20, ptr %21, align 2 | ||
| ret void | ||
| } | ||
|
|
||
| define ptx_kernel void @add_f16(ptr addrspace(1) %0, { half, half } %1, { half, half } %2) { | ||
| ; VECTOR-LABEL: @add_f16( | ||
| ; VECTOR-NEXT: [[TMP4:%.*]] = extractvalue { half, half } [[TMP1:%.*]], 0 | ||
| ; VECTOR-NEXT: [[TMP5:%.*]] = extractvalue { half, half } [[TMP1]], 1 | ||
| ; VECTOR-NEXT: [[TMP6:%.*]] = extractvalue { half, half } [[TMP2:%.*]], 0 | ||
| ; VECTOR-NEXT: [[TMP7:%.*]] = extractvalue { half, half } [[TMP2]], 1 | ||
| ; VECTOR-NEXT: [[TMP8:%.*]] = insertelement <2 x half> poison, half [[TMP4]], i32 0 | ||
| ; VECTOR-NEXT: [[TMP9:%.*]] = insertelement <2 x half> [[TMP8]], half [[TMP5]], i32 1 | ||
| ; VECTOR-NEXT: [[TMP10:%.*]] = insertelement <2 x half> poison, half [[TMP6]], i32 0 | ||
| ; VECTOR-NEXT: [[TMP11:%.*]] = insertelement <2 x half> [[TMP10]], half [[TMP7]], i32 1 | ||
| ; VECTOR-NEXT: [[TMP12:%.*]] = fadd <2 x half> [[TMP9]], [[TMP11]] | ||
| ; VECTOR-NEXT: [[TMP13:%.*]] = call i32 @llvm.nvvm.read.ptx.sreg.tid.x() | ||
| ; VECTOR-NEXT: [[TMP14:%.*]] = shl i32 [[TMP13]], 1 | ||
| ; VECTOR-NEXT: [[TMP15:%.*]] = and i32 [[TMP14]], 62 | ||
| ; VECTOR-NEXT: [[TMP16:%.*]] = zext nneg i32 [[TMP15]] to i64 | ||
| ; VECTOR-NEXT: [[TMP17:%.*]] = getelementptr half, ptr addrspace(1) [[TMP0:%.*]], i64 [[TMP16]] | ||
| ; VECTOR-NEXT: store <2 x half> [[TMP12]], ptr addrspace(1) [[TMP17]], align 4 | ||
| ; VECTOR-NEXT: ret void | ||
| ; | ||
| ; NOVECTOR-LABEL: @add_f16( | ||
| ; NOVECTOR-NEXT: [[TMP4:%.*]] = extractvalue { half, half } [[TMP1:%.*]], 0 | ||
| ; NOVECTOR-NEXT: [[TMP5:%.*]] = extractvalue { half, half } [[TMP1]], 1 | ||
| ; NOVECTOR-NEXT: [[TMP6:%.*]] = extractvalue { half, half } [[TMP2:%.*]], 0 | ||
| ; NOVECTOR-NEXT: [[TMP7:%.*]] = extractvalue { half, half } [[TMP2]], 1 | ||
| ; NOVECTOR-NEXT: [[TMP8:%.*]] = fadd half [[TMP4]], [[TMP6]] | ||
| ; NOVECTOR-NEXT: [[TMP9:%.*]] = fadd half [[TMP5]], [[TMP7]] | ||
| ; NOVECTOR-NEXT: [[TMP13:%.*]] = call i32 @llvm.nvvm.read.ptx.sreg.tid.x() | ||
| ; NOVECTOR-NEXT: [[TMP14:%.*]] = shl i32 [[TMP13]], 1 | ||
| ; NOVECTOR-NEXT: [[TMP15:%.*]] = and i32 [[TMP14]], 62 | ||
| ; NOVECTOR-NEXT: [[TMP16:%.*]] = zext nneg i32 [[TMP15]] to i64 | ||
| ; NOVECTOR-NEXT: [[TMP17:%.*]] = getelementptr half, ptr addrspace(1) [[TMP0:%.*]], i64 [[TMP16]] | ||
| ; NOVECTOR-NEXT: [[TMP19:%.*]] = insertelement <2 x half> poison, half [[TMP8]], i64 0 | ||
| ; NOVECTOR-NEXT: [[TMP12:%.*]] = insertelement <2 x half> [[TMP19]], half [[TMP9]], i64 1 | ||
| ; NOVECTOR-NEXT: store <2 x half> [[TMP12]], ptr addrspace(1) [[TMP17]], align 4 | ||
| ; NOVECTOR-NEXT: ret void | ||
| ; | ||
| %5 = extractvalue { half, half } %1, 0 | ||
| %6 = extractvalue { half, half } %1, 1 | ||
| %7 = extractvalue { half, half } %2, 0 | ||
| %8 = extractvalue { half, half } %2, 1 | ||
| %9 = fadd half %5, %7 | ||
| %10 = fadd half %6, %8 | ||
| %11 = call i32 @llvm.nvvm.read.ptx.sreg.tid.x() | ||
| %12 = shl i32 %11, 1 | ||
| %13 = and i32 %12, 62 | ||
| %14 = zext nneg i32 %13 to i64 | ||
| %15 = getelementptr half, ptr addrspace(1) %0, i64 %14 | ||
| %18 = insertelement <2 x half> poison, half %9, i64 0 | ||
| %19 = insertelement <2 x half> %18, half %10, i64 1 | ||
| store <2 x half> %19, ptr addrspace(1) %15, align 4 | ||
| ret void | ||
| } | ||
|
|
||
| ; Function Attrs: nocallback nofree nosync nounwind speculatable willreturn memory(none) | ||
| declare noundef i32 @llvm.nvvm.read.ptx.sreg.tid.x() #1 | ||
|
|
||
| attributes #0 = { nounwind } | ||
| attributes #1 = { nocallback nofree nosync nounwind speculatable willreturn memory(none) } |
Uh oh!
There was an error while loading. Please reload this page.