diff --git a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp index 702206b8e0dc5..52614d378c465 100644 --- a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp @@ -427,7 +427,7 @@ Type *SPIRVEmitIntrinsics::reconstructType(Value *Op, bool UnknownElemTypeI8, void SPIRVEmitIntrinsics::buildAssignType(IRBuilder<> &B, Type *Ty, Value *Arg) { - Value *OfType = PoisonValue::get(Ty); + Value *OfType = getNormalizedPoisonValue(Ty); CallInst *AssignCI = nullptr; if (Arg->getType()->isAggregateType() && Ty->isAggregateType() && allowEmitFakeUse(Arg)) { @@ -447,6 +447,7 @@ void SPIRVEmitIntrinsics::buildAssignType(IRBuilder<> &B, Type *Ty, void SPIRVEmitIntrinsics::buildAssignPtr(IRBuilder<> &B, Type *ElemTy, Value *Arg) { + ElemTy = normalizeType(ElemTy); Value *OfType = PoisonValue::get(ElemTy); CallInst *AssignPtrTyCI = GR->findAssignPtrTypeInstr(Arg); if (AssignPtrTyCI == nullptr || @@ -470,7 +471,7 @@ void SPIRVEmitIntrinsics::updateAssignType(CallInst *AssignCI, Value *Arg, return; // update association with the pointee type - Type *ElemTy = OfType->getType(); + Type *ElemTy = normalizeType(OfType->getType()); GR->addDeducedElementType(AssignCI, ElemTy); GR->addDeducedElementType(Arg, ElemTy); } @@ -490,7 +491,7 @@ CallInst *SPIRVEmitIntrinsics::buildSpvPtrcast(Function *F, Value *Op, } Type *OpTy = Op->getType(); SmallVector Types = {OpTy, OpTy}; - SmallVector Args = {Op, buildMD(PoisonValue::get(ElemTy)), + SmallVector Args = {Op, buildMD(getNormalizedPoisonValue(ElemTy)), B.getInt32(getPointerAddressSpace(OpTy))}; CallInst *PtrCasted = B.CreateIntrinsic(Intrinsic::spv_ptrcast, {Types}, Args); @@ -766,7 +767,7 @@ Type *SPIRVEmitIntrinsics::deduceElementTypeHelper( // remember the found relationship if (Ty && !IgnoreKnownType) { // specify nested types if needed, otherwise return unchanged - GR->addDeducedElementType(I, Ty); + GR->addDeducedElementType(I, normalizeType(Ty)); } 
return Ty; @@ -852,7 +853,7 @@ Type *SPIRVEmitIntrinsics::deduceNestedTypeHelper( } if (Ty != OpTy) { Type *NewTy = VectorType::get(Ty, VecTy->getElementCount()); - GR->addDeducedCompositeType(U, NewTy); + GR->addDeducedCompositeType(U, normalizeType(NewTy)); return NewTy; } } @@ -990,6 +991,7 @@ bool SPIRVEmitIntrinsics::deduceOperandElementTypeFunctionRet( if (KnownElemTy) return false; if (Type *OpElemTy = GR->findDeducedElementType(Op)) { + OpElemTy = normalizeType(OpElemTy); GR->addDeducedElementType(F, OpElemTy); GR->addReturnType( F, TypedPointerType::get(OpElemTy, @@ -1002,7 +1004,7 @@ bool SPIRVEmitIntrinsics::deduceOperandElementTypeFunctionRet( continue; if (CallInst *AssignCI = GR->findAssignPtrTypeInstr(CI)) { if (Type *PrevElemTy = GR->findDeducedElementType(CI)) { - updateAssignType(AssignCI, CI, PoisonValue::get(OpElemTy)); + updateAssignType(AssignCI, CI, getNormalizedPoisonValue(OpElemTy)); propagateElemType(CI, PrevElemTy, VisitedSubst); } } @@ -1162,11 +1164,11 @@ void SPIRVEmitIntrinsics::deduceOperandElementType( Type *Ty = AskTy ? AskTy : GR->findDeducedElementType(Op); if (Ty == KnownElemTy) continue; - Value *OpTyVal = PoisonValue::get(KnownElemTy); + Value *OpTyVal = getNormalizedPoisonValue(KnownElemTy); Type *OpTy = Op->getType(); if (!Ty || AskTy || isUntypedPointerTy(Ty) || isTodoType(Op)) { Type *PrevElemTy = GR->findDeducedElementType(Op); - GR->addDeducedElementType(Op, KnownElemTy); + GR->addDeducedElementType(Op, normalizeType(KnownElemTy)); // check if KnownElemTy is complete if (!Uncomplete) eraseTodoType(Op); @@ -1492,7 +1494,7 @@ void SPIRVEmitIntrinsics::insertAssignPtrTypeTargetExt( // Our previous guess about the type seems to be wrong, let's update // inferred type according to a new, more precise type information. 
- updateAssignType(AssignCI, V, PoisonValue::get(AssignedType)); + updateAssignType(AssignCI, V, getNormalizedPoisonValue(AssignedType)); } void SPIRVEmitIntrinsics::replacePointerOperandWithPtrCast( @@ -1507,7 +1509,7 @@ void SPIRVEmitIntrinsics::replacePointerOperandWithPtrCast( return; setInsertPointSkippingPhis(B, I); - Value *ExpectedElementVal = PoisonValue::get(ExpectedElementType); + Value *ExpectedElementVal = getNormalizedPoisonValue(ExpectedElementType); MetadataAsValue *VMD = buildMD(ExpectedElementVal); unsigned AddressSpace = getPointerAddressSpace(Pointer->getType()); bool FirstPtrCastOrAssignPtrType = true; @@ -1653,7 +1655,7 @@ void SPIRVEmitIntrinsics::insertPtrCastOrAssignTypeInstr(Instruction *I, if (!ElemTy) { ElemTy = getPointeeTypeByCallInst(DemangledName, CalledF, OpIdx); if (ElemTy) { - GR->addDeducedElementType(CalledArg, ElemTy); + GR->addDeducedElementType(CalledArg, normalizeType(ElemTy)); } else { for (User *U : CalledArg->users()) { if (Instruction *Inst = dyn_cast(U)) { @@ -1704,6 +1706,11 @@ void SPIRVEmitIntrinsics::insertPtrCastOrAssignTypeInstr(Instruction *I, } Instruction *SPIRVEmitIntrinsics::visitInsertElementInst(InsertElementInst &I) { + // If it's a <1 x Type> vector type, don't modify it. It's not a legal vector + // type in LLT and IRTranslator will replace it by the scalar. + if (isVector1(I.getType())) + return &I; + SmallVector Types = {I.getType(), I.getOperand(0)->getType(), I.getOperand(1)->getType(), I.getOperand(2)->getType()}; @@ -1717,6 +1724,11 @@ Instruction *SPIRVEmitIntrinsics::visitInsertElementInst(InsertElementInst &I) { Instruction * SPIRVEmitIntrinsics::visitExtractElementInst(ExtractElementInst &I) { + // If it's a <1 x Type> vector type, don't modify it. It's not a legal vector + // type in LLT and IRTranslator will replace it by the scalar. 
+ if (isVector1(I.getVectorOperandType())) + return &I; + IRBuilder<> B(I.getParent()); B.SetInsertPoint(&I); SmallVector Types = {I.getType(), I.getVectorOperandType(), @@ -1984,8 +1996,9 @@ void SPIRVEmitIntrinsics::insertAssignTypeIntrs(Instruction *I, Type *ElemTy = GR->findDeducedElementType(Op); buildAssignPtr(B, ElemTy ? ElemTy : deduceElementType(Op, true), Op); } else { - CallInst *AssignCI = buildIntrWithMD(Intrinsic::spv_assign_type, - {OpTy}, Op, Op, {}, B); + CallInst *AssignCI = + buildIntrWithMD(Intrinsic::spv_assign_type, {OpTy}, + getNormalizedPoisonValue(OpTy), Op, {}, B); GR->addAssignPtrTypeInstr(Op, AssignCI); } } @@ -2034,7 +2047,7 @@ void SPIRVEmitIntrinsics::processInstrAfterVisit(Instruction *I, Type *OpTy = Op->getType(); Value *OpTyVal = Op; if (OpTy->isTargetExtTy()) - OpTyVal = PoisonValue::get(OpTy); + OpTyVal = getNormalizedPoisonValue(OpTy); CallInst *NewOp = buildIntrWithMD(Intrinsic::spv_track_constant, {OpTy, OpTyVal->getType()}, Op, OpTyVal, {}, B); @@ -2045,7 +2058,7 @@ void SPIRVEmitIntrinsics::processInstrAfterVisit(Instruction *I, buildAssignPtr(B, IntegerType::getInt8Ty(I->getContext()), NewOp); SmallVector Types = {OpTy, OpTy}; SmallVector Args = { - NewOp, buildMD(PoisonValue::get(OpElemTy)), + NewOp, buildMD(getNormalizedPoisonValue(OpElemTy)), B.getInt32(getPointerAddressSpace(OpTy))}; CallInst *PtrCasted = B.CreateIntrinsic(Intrinsic::spv_ptrcast, {Types}, Args); @@ -2178,7 +2191,7 @@ void SPIRVEmitIntrinsics::processParamTypes(Function *F, IRBuilder<> &B) { if (!ElemTy && (ElemTy = deduceFunParamElementType(F, OpIdx)) != nullptr) { if (CallInst *AssignCI = GR->findAssignPtrTypeInstr(Arg)) { DenseSet> VisitedSubst; - updateAssignType(AssignCI, Arg, PoisonValue::get(ElemTy)); + updateAssignType(AssignCI, Arg, getNormalizedPoisonValue(ElemTy)); propagateElemType(Arg, IntegerType::getInt8Ty(F->getContext()), VisitedSubst); } else { @@ -2232,7 +2245,7 @@ bool SPIRVEmitIntrinsics::processFunctionPointers(Module &M) { 
continue; if (II->getIntrinsicID() == Intrinsic::spv_assign_ptr_type || II->getIntrinsicID() == Intrinsic::spv_ptrcast) { - updateAssignType(II, &F, PoisonValue::get(FPElemTy)); + updateAssignType(II, &F, getNormalizedPoisonValue(FPElemTy)); break; } } @@ -2256,7 +2269,7 @@ bool SPIRVEmitIntrinsics::processFunctionPointers(Module &M) { for (Function *F : Worklist) { SmallVector Args; for (const auto &Arg : F->args()) - Args.push_back(PoisonValue::get(Arg.getType())); + Args.push_back(getNormalizedPoisonValue(Arg.getType())); IRB.CreateCall(F, Args); } IRB.CreateRetVoid(); @@ -2286,7 +2299,7 @@ void SPIRVEmitIntrinsics::applyDemangledPtrArgTypes(IRBuilder<> &B) { buildAssignPtr(B, ElemTy, Arg); } } else if (isa(Param)) { - GR->addDeducedElementType(Param, ElemTy); + GR->addDeducedElementType(Param, normalizeType(ElemTy)); // insertAssignTypeIntrs() will complete buildAssignPtr() } else { B.SetInsertPoint(CI->getParent() @@ -2302,6 +2315,7 @@ void SPIRVEmitIntrinsics::applyDemangledPtrArgTypes(IRBuilder<> &B) { if (!RefF || !isPointerTy(RefF->getReturnType()) || GR->findDeducedElementType(RefF)) continue; + ElemTy = normalizeType(ElemTy); GR->addDeducedElementType(RefF, ElemTy); GR->addReturnType( RefF, TypedPointerType::get( diff --git a/llvm/lib/Target/SPIRV/SPIRVUtils.h b/llvm/lib/Target/SPIRV/SPIRVUtils.h index fd48098257065..552adf2df7d17 100644 --- a/llvm/lib/Target/SPIRV/SPIRVUtils.h +++ b/llvm/lib/Target/SPIRV/SPIRVUtils.h @@ -383,6 +383,28 @@ inline const Type *unifyPtrType(const Type *Ty) { return toTypedPointer(const_cast(Ty)); } +inline bool isVector1(Type *Ty) { + auto *FVTy = dyn_cast(Ty); + return FVTy && FVTy->getNumElements() == 1; +} + +// Modify an LLVM type to conform with future transformations in IRTranslator. +// At the moment use cases comprise only a <1 x Type> vector. To extend when/if +// needed. 
+inline Type *normalizeType(Type *Ty) {
+  auto *FVTy = dyn_cast<FixedVectorType>(Ty);
+  if (!FVTy || FVTy->getNumElements() != 1)
+    return Ty;
+  // If it's a <1 x Type> vector type, replace it by the element type, because
+  // it's not a legal vector type in LLT and IRTranslator will represent it as
+  // the scalar eventually.
+  return normalizeType(FVTy->getElementType());
+}
+
+inline PoisonValue *getNormalizedPoisonValue(Type *Ty) {
+  return PoisonValue::get(normalizeType(Ty));
+}
+
 MachineInstr *getVRegDef(MachineRegisterInfo &MRI, Register Reg);
 
 #define SPIRV_BACKEND_SERVICE_FUN_NAME "__spirv_backend_service_fun"
diff --git a/llvm/test/CodeGen/SPIRV/validate/triton-tut-softmax-kernel.ll b/llvm/test/CodeGen/SPIRV/validate/triton-tut-softmax-kernel.ll
new file mode 100644
index 0000000000000..d8a6c85b3d407
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/validate/triton-tut-softmax-kernel.ll
@@ -0,0 +1,221 @@
+; This is an excerpt from the tutorial of the Triton language converted into
+; LLVM IR via the Triton XPU backend and cleaned of irrelevant details.
+; The only pass criterion is that spirv-val considers output valid.
+
+; This particular case is related to translation of <1 x Ty> vectors.
+ +; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown %s -o - -filetype=obj | spirv-val --target-env spv1.4 %} + +define spir_kernel void @softmax_kernel(ptr addrspace(1) nocapture writeonly %0, ptr addrspace(1) nocapture readonly %1, i32 %2, i32 %3, i32 %4, i32 %5, ptr addrspace(3) nocapture %6) { + %8 = tail call spir_func i64 @_Z12get_group_idj(i32 0) + %9 = trunc i64 %8 to i32 + %10 = tail call spir_func i64 @_Z14get_num_groupsj(i32 0) + %11 = trunc i64 %10 to i32 + %12 = tail call spir_func i64 @_Z12get_local_idj(i32 0) + %13 = trunc i64 %12 to i32 + %14 = and i32 %13, 255 + %15 = or disjoint i32 %14, 256 + %16 = or disjoint i32 %14, 512 + %17 = or disjoint i32 %14, 768 + %18 = icmp slt i32 %14, %5 + %19 = icmp slt i32 %15, %5 + %20 = icmp slt i32 %16, %5 + %21 = icmp slt i32 %17, %5 + %22 = icmp sgt i32 %4, %9 + br i1 %22, label %.lr.ph, label %._crit_edge + +.lr.ph: ; preds = %7 + %23 = lshr i64 %12, 5 + %24 = and i32 %13, 31 + %25 = zext nneg i32 %15 to i64 + %26 = zext nneg i32 %16 to i64 + %27 = zext nneg i32 %17 to i64 + %28 = and i64 %12, 255 + %29 = and i64 %23, 7 + %30 = icmp eq i32 %24, 0 + %31 = getelementptr float, ptr addrspace(3) %6, i64 %29 + %32 = icmp slt i32 %13, 8 + %sext = shl i64 %12, 32 + %33 = ashr exact i64 %sext, 30 + %34 = getelementptr i8, ptr addrspace(3) %6, i64 %33 + %35 = and i32 %13, 7 + %36 = icmp eq i32 %35, 0 + %37 = and i1 %32, %36 + br label %38 + +38: ; preds = %.lr.ph, %123 + %39 = phi i32 [ %9, %.lr.ph ], [ %124, %123 ] + %40 = mul i32 %39, %2 + %41 = sext i32 %40 to i64 + %42 = getelementptr float, ptr addrspace(1) %1, i64 %41 + %43 = getelementptr float, ptr addrspace(1) %42, i64 %25 + %44 = getelementptr float, ptr addrspace(1) %42, i64 %26 + %45 = getelementptr float, ptr addrspace(1) %42, i64 %27 + br i1 %18, label %46, label %49 + +46: ; preds = %38 + %47 = getelementptr float, ptr addrspace(1) %42, i64 %28 + %48 = load <1 x float>, ptr addrspace(1) %47, align 4 + br label %49 + +49: ; preds = %46, 
%38 + %50 = phi <1 x float> [ %48, %46 ], [ splat (float 0xFFF0000000000000), %38 ] + %51 = extractelement <1 x float> %50, i64 0 + br i1 %19, label %52, label %54 + +52: ; preds = %49 + %53 = load <1 x float>, ptr addrspace(1) %43, align 4 + br label %54 + +54: ; preds = %52, %49 + %55 = phi <1 x float> [ %53, %52 ], [ splat (float 0xFFF0000000000000), %49 ] + %56 = extractelement <1 x float> %55, i64 0 + br i1 %20, label %57, label %59 + +57: ; preds = %54 + %58 = load <1 x float>, ptr addrspace(1) %44, align 4 + br label %59 + +59: ; preds = %57, %54 + %60 = phi <1 x float> [ %58, %57 ], [ splat (float 0xFFF0000000000000), %54 ] + %61 = extractelement <1 x float> %60, i64 0 + br i1 %21, label %62, label %64 + +62: ; preds = %59 + %63 = load <1 x float>, ptr addrspace(1) %45, align 4 + br label %64 + +64: ; preds = %62, %59 + %65 = phi <1 x float> [ %63, %62 ], [ splat (float 0xFFF0000000000000), %59 ] + %66 = extractelement <1 x float> %65, i64 0 + tail call spir_func void @_Z7barrierj(i32 1) + %67 = tail call float @llvm.maxnum.f32(float %51, float %56) + %68 = tail call float @llvm.maxnum.f32(float %67, float %61) + %69 = tail call float @llvm.maxnum.f32(float %68, float %66) + %70 = tail call spir_func float @_Z27__spirv_GroupNonUniformFMaxiif(i32 3, i32 0, float %69) + br i1 %30, label %71, label %72 + +71: ; preds = %64 + store float %70, ptr addrspace(3) %31, align 4 + br label %72 + +72: ; preds = %71, %64 + tail call spir_func void @_Z7barrierj(i32 1) + br i1 %32, label %74, label %.thread1 + +.thread1: ; preds = %72 + %73 = tail call spir_func float @_Z27__spirv_GroupNonUniformFMaxiifj(i32 3, i32 3, float poison, i32 8) + br label %78 + +74: ; preds = %72 + %75 = load float, ptr addrspace(3) %34, align 4 + %76 = tail call spir_func float @_Z27__spirv_GroupNonUniformFMaxiifj(i32 3, i32 3, float %75, i32 8) + br i1 %37, label %77, label %78 + +77: ; preds = %74 + store float %76, ptr addrspace(3) %34, align 4 + br label %78 + +78: ; preds = %.thread1, 
%77, %74 + tail call spir_func void @_Z7barrierj(i32 1) + %79 = load float, ptr addrspace(3) %6, align 4 + %80 = fsub float %51, %79 + %81 = fsub float %56, %79 + %82 = fsub float %61, %79 + %83 = fsub float %66, %79 + %84 = fmul float %80, 0x3FF7154760000000 + %85 = tail call float @llvm.exp2.f32(float %84) + %86 = fmul float %81, 0x3FF7154760000000 + %87 = tail call float @llvm.exp2.f32(float %86) + %88 = fmul float %82, 0x3FF7154760000000 + %89 = tail call float @llvm.exp2.f32(float %88) + %90 = fmul float %83, 0x3FF7154760000000 + %91 = tail call float @llvm.exp2.f32(float %90) + tail call spir_func void @_Z7barrierj(i32 1) + %92 = fadd float %85, %87 + %93 = fadd float %89, %92 + %94 = fadd float %91, %93 + %95 = tail call spir_func float @_Z27__spirv_GroupNonUniformFAddiif(i32 3, i32 0, float %94) + br i1 %30, label %96, label %97 + +96: ; preds = %78 + store float %95, ptr addrspace(3) %31, align 4 + br label %97 + +97: ; preds = %96, %78 + tail call spir_func void @_Z7barrierj(i32 1) + br i1 %32, label %99, label %.thread + +.thread: ; preds = %97 + %98 = tail call spir_func float @_Z27__spirv_GroupNonUniformFAddiifj(i32 3, i32 3, float poison, i32 8) + br label %103 + +99: ; preds = %97 + %100 = load float, ptr addrspace(3) %34, align 4 + %101 = tail call spir_func float @_Z27__spirv_GroupNonUniformFAddiifj(i32 3, i32 3, float %100, i32 8) + br i1 %37, label %102, label %103 + +102: ; preds = %99 + store float %101, ptr addrspace(3) %34, align 4 + br label %103 + +103: ; preds = %.thread, %102, %99 + tail call spir_func void @_Z7barrierj(i32 1) + %104 = load float, ptr addrspace(3) %6, align 4 + %105 = fdiv float %87, %104 + %106 = fdiv float %89, %104 + %107 = fdiv float %91, %104 + %108 = mul i32 %39, %3 + %109 = sext i32 %108 to i64 + %110 = getelementptr float, ptr addrspace(1) %0, i64 %109 + %111 = getelementptr float, ptr addrspace(1) %110, i64 %25 + %112 = getelementptr float, ptr addrspace(1) %110, i64 %26 + %113 = getelementptr float, ptr 
addrspace(1) %110, i64 %27 + br i1 %18, label %114, label %117 + +114: ; preds = %103 + %115 = fdiv float %85, %104 + %116 = getelementptr float, ptr addrspace(1) %110, i64 %28 + store float %115, ptr addrspace(1) %116, align 4 + br label %117 + +117: ; preds = %114, %103 + br i1 %19, label %118, label %119 + +118: ; preds = %117 + store float %105, ptr addrspace(1) %111, align 4 + br label %119 + +119: ; preds = %118, %117 + br i1 %20, label %120, label %121 + +120: ; preds = %119 + store float %106, ptr addrspace(1) %112, align 4 + br label %121 + +121: ; preds = %120, %119 + br i1 %21, label %122, label %123 + +122: ; preds = %121 + store float %107, ptr addrspace(1) %113, align 4 + br label %123 + +123: ; preds = %122, %121 + %124 = add i32 %39, %11 + %125 = icmp slt i32 %124, %4 + br i1 %125, label %38, label %._crit_edge + +._crit_edge: ; preds = %123, %7 + ret void +} + +declare float @llvm.maxnum.f32(float, float) +declare spir_func float @_Z27__spirv_GroupNonUniformFAddiifj(i32, i32, float, i32) +declare spir_func float @_Z27__spirv_GroupNonUniformFAddiif(i32, i32, float) +declare spir_func float @_Z27__spirv_GroupNonUniformFMaxiifj(i32, i32, float, i32) +declare spir_func float @_Z27__spirv_GroupNonUniformFMaxiif(i32, i32, float) +declare spir_func void @_Z7barrierj(i32) +declare spir_func i64 @_Z12get_local_idj(i32) +declare spir_func i64 @_Z14get_num_groupsj(i32) +declare spir_func i64 @_Z12get_group_idj(i32) +declare float @llvm.exp2.f32(float)