diff --git a/include/dxc/dxcapi.internal.h b/include/dxc/dxcapi.internal.h index 46a485206e..780b35ced9 100644 --- a/include/dxc/dxcapi.internal.h +++ b/include/dxc/dxcapi.internal.h @@ -136,7 +136,8 @@ enum LEGAL_INTRINSIC_COMPTYPES { #ifdef ENABLE_SPIRV_CODEGEN LICOMPTYPE_VK_BUFFER_POINTER = 54, - LICOMPTYPE_COUNT = 55 + LICOMPTYPE_VK_SAMPLED_TEXTURE2D = 55, + LICOMPTYPE_COUNT = 56 #else LICOMPTYPE_COUNT = 54 #endif diff --git a/tools/clang/include/clang/AST/HlslTypes.h b/tools/clang/include/clang/AST/HlslTypes.h index 43c1effdb8..b681d6f979 100644 --- a/tools/clang/include/clang/AST/HlslTypes.h +++ b/tools/clang/include/clang/AST/HlslTypes.h @@ -407,6 +407,11 @@ clang::CXXRecordDecl * DeclareVkBufferPointerType(clang::ASTContext &context, clang::DeclContext *declContext); +clang::CXXRecordDecl *DeclareVkSampledTexture2DType( + clang::ASTContext &context, clang::DeclContext *declContext, + clang::QualType float2Type, clang::QualType int2Type, + clang::QualType float4Type); + clang::CXXRecordDecl *DeclareInlineSpirvType(clang::ASTContext &context, clang::DeclContext *declContext, llvm::StringRef typeName, diff --git a/tools/clang/include/clang/SPIRV/AstTypeProbe.h b/tools/clang/include/clang/SPIRV/AstTypeProbe.h index 45bff1bad4..1479075f12 100644 --- a/tools/clang/include/clang/SPIRV/AstTypeProbe.h +++ b/tools/clang/include/clang/SPIRV/AstTypeProbe.h @@ -257,6 +257,9 @@ bool isTexture(QualType); /// Texture2DMSArray type. bool isTextureMS(QualType); +/// \brief Returns true if the given type is an HLSL SampledTexture type. +bool isSampledTexture(QualType); + /// \brief Returns true if the given type is an HLSL RWTexture type. bool isRWTexture(QualType); diff --git a/tools/clang/include/clang/SPIRV/SpirvBuilder.h b/tools/clang/include/clang/SPIRV/SpirvBuilder.h index 1d012568d6..75ad453d25 100644 --- a/tools/clang/include/clang/SPIRV/SpirvBuilder.h +++ b/tools/clang/include/clang/SPIRV/SpirvBuilder.h @@ -286,6 +286,9 @@ class SpirvBuilder { /// If compareVal is given a non-zero value, *Dref* variants of OpImageSample* /// will be generated. /// + /// If imageType is not a sampled image type, the OpSampledImage* instructions + /// will be generated. + /// /// If lod or grad is given a non-zero value, *ExplicitLod variants of /// OpImageSample* will be generated; otherwise, *ImplicitLod variant will /// be generated. @@ -334,6 +337,8 @@ class SpirvBuilder { /// \brief Creates SPIR-V instructions for gathering the given image. /// + /// If imageType is not a sampled image type, the OpSampledImage* instructions + /// will be generated. /// If compareVal is given a non-null value, OpImageDrefGather or /// OpImageSparseDrefGather will be generated; otherwise, OpImageGather or /// OpImageSparseGather will be generated. diff --git a/tools/clang/lib/AST/ASTContextHLSL.cpp b/tools/clang/lib/AST/ASTContextHLSL.cpp index 913b28ced8..2ece3fdd0a 100644 --- a/tools/clang/lib/AST/ASTContextHLSL.cpp +++ b/tools/clang/lib/AST/ASTContextHLSL.cpp @@ -1369,6 +1369,130 @@ CXXRecordDecl *hlsl::DeclareNodeOrRecordType( } #ifdef ENABLE_SPIRV_CODEGEN +CXXRecordDecl *hlsl::DeclareVkSampledTexture2DType(ASTContext &context, + DeclContext *declContext, + QualType float2Type, + QualType int2Type, + QualType float4Type) { + BuiltinTypeDeclBuilder Builder(declContext, "SampledTexture2D", + TagDecl::TagKind::TTK_Struct); + + QualType defaultTextureType = float4Type; + TemplateTypeParmDecl *TyParamDecl = + Builder.addTypeTemplateParam("SampledTextureType", defaultTextureType); + + Builder.startDefinition(); + + QualType paramType = QualType(TyParamDecl->getTypeForDecl(), 0); + CXXRecordDecl *recordDecl = Builder.getRecordDecl(); + + QualType floatType = context.FloatTy; + QualType uintType = context.UnsignedIntTy; + // Add Sample method + // Sample(location) + CXXMethodDecl *sampleDecl = CreateObjectFunctionDeclarationWithParams( + context, recordDecl, paramType, ArrayRef(float2Type), + ArrayRef(StringRef("location")), + context.DeclarationNames.getIdentifier(&context.Idents.get("Sample")), + /*isConst*/ true); + sampleDecl->addAttr(HLSLIntrinsicAttr::CreateImplicit( + context, "op", "", static_cast(hlsl::IntrinsicOp::MOP_Sample))); + sampleDecl->addAttr(HLSLCXXOverloadAttr::CreateImplicit(context)); + + // Sample(location, offset) + QualType params2[] = {float2Type, int2Type}; + StringRef names2[] = {"location", "offset"}; + CXXMethodDecl *sampleDecl2 = CreateObjectFunctionDeclarationWithParams( + context, recordDecl, paramType, params2, names2, + context.DeclarationNames.getIdentifier(&context.Idents.get("Sample")), + /*isConst*/ true); + sampleDecl2->addAttr(HLSLIntrinsicAttr::CreateImplicit( + context, "op", "", static_cast(hlsl::IntrinsicOp::MOP_Sample))); + sampleDecl2->addAttr(HLSLCXXOverloadAttr::CreateImplicit(context)); + + // Sample(location, offset, clamp) + QualType params3[] = {float2Type, int2Type, floatType}; + StringRef names3[] = {"location", "offset", "clamp"}; + CXXMethodDecl *sampleDecl3 = CreateObjectFunctionDeclarationWithParams( + context, recordDecl, paramType, params3, names3, + context.DeclarationNames.getIdentifier(&context.Idents.get("Sample")), + /*isConst*/ true); + sampleDecl3->addAttr(HLSLIntrinsicAttr::CreateImplicit( + context, "op", "", static_cast(hlsl::IntrinsicOp::MOP_Sample))); + sampleDecl3->addAttr(HLSLCXXOverloadAttr::CreateImplicit(context)); + + // Sample(location, offset, clamp, status) + QualType params4[] = {float2Type, int2Type, floatType, + context.getLValueReferenceType(uintType)}; + StringRef names4[] = {"location", "offset", "clamp", "status"}; + CXXMethodDecl *sampleDecl4 = CreateObjectFunctionDeclarationWithParams( + context, recordDecl, paramType, params4, names4, + context.DeclarationNames.getIdentifier(&context.Idents.get("Sample")), + /*isConst*/ true); + sampleDecl4->addAttr(HLSLIntrinsicAttr::CreateImplicit( + context, "op", "", static_cast(hlsl::IntrinsicOp::MOP_Sample))); + sampleDecl4->addAttr(HLSLCXXOverloadAttr::CreateImplicit(context)); + + // CalculateLevelOfDetail(location) + CXXMethodDecl *lodDecl = CreateObjectFunctionDeclarationWithParams( + context, recordDecl, floatType, ArrayRef(float2Type), + ArrayRef(StringRef("location")), + context.DeclarationNames.getIdentifier( + &context.Idents.get("CalculateLevelOfDetail")), + /*isConst*/ true); + lodDecl->addAttr(HLSLIntrinsicAttr::CreateImplicit( + context, "op", "", + static_cast(hlsl::IntrinsicOp::MOP_CalculateLevelOfDetail))); + lodDecl->addAttr(HLSLCXXOverloadAttr::CreateImplicit(context)); + + // CalculateLevelOfDetailUnclamped(location) + CXXMethodDecl *lodUnclampedDecl = CreateObjectFunctionDeclarationWithParams( + context, recordDecl, floatType, ArrayRef(float2Type), + ArrayRef(StringRef("location")), + context.DeclarationNames.getIdentifier( + &context.Idents.get("CalculateLevelOfDetailUnclamped")), + /*isConst*/ true); + lodUnclampedDecl->addAttr(HLSLIntrinsicAttr::CreateImplicit( + context, "op", "", + static_cast( + hlsl::IntrinsicOp::MOP_CalculateLevelOfDetailUnclamped))); + lodUnclampedDecl->addAttr(HLSLCXXOverloadAttr::CreateImplicit(context)); + + // Gather(location) + CXXMethodDecl *gatherDecl = CreateObjectFunctionDeclarationWithParams( + context, recordDecl, paramType, ArrayRef(float2Type), + ArrayRef(StringRef("location")), + context.DeclarationNames.getIdentifier(&context.Idents.get("Gather")), + /*isConst*/ true); + gatherDecl->addAttr(HLSLIntrinsicAttr::CreateImplicit( + context, "op", "", static_cast(hlsl::IntrinsicOp::MOP_Gather))); + gatherDecl->addAttr(HLSLCXXOverloadAttr::CreateImplicit(context)); + // Gather(location, offset) + QualType gatherParams2[] = {float2Type, int2Type}; + StringRef gatherNames2[] = {"location", "offset"}; + CXXMethodDecl *gatherDecl2 = CreateObjectFunctionDeclarationWithParams( + context, recordDecl, paramType, gatherParams2, gatherNames2, + context.DeclarationNames.getIdentifier(&context.Idents.get("Gather")), + /*isConst*/ true); + gatherDecl2->addAttr(HLSLIntrinsicAttr::CreateImplicit( + context, "op", "", static_cast(hlsl::IntrinsicOp::MOP_Gather))); + gatherDecl2->addAttr(HLSLCXXOverloadAttr::CreateImplicit(context)); + // Gather(location, offset, status) + QualType gatherParams3[] = {float2Type, int2Type, + context.getLValueReferenceType(uintType)}; + StringRef gatherNames3[] = {"location", "offset", "status"}; + CXXMethodDecl *gatherDecl3 = CreateObjectFunctionDeclarationWithParams( + context, recordDecl, paramType, gatherParams3, gatherNames3, + context.DeclarationNames.getIdentifier(&context.Idents.get("Gather")), + /*isConst*/ true); + gatherDecl3->addAttr(HLSLIntrinsicAttr::CreateImplicit( + context, "op", "", static_cast(hlsl::IntrinsicOp::MOP_Gather))); + gatherDecl3->addAttr(HLSLCXXOverloadAttr::CreateImplicit(context)); + + Builder.completeDefinition(); + return recordDecl; +} + CXXRecordDecl *hlsl::DeclareVkBufferPointerType(ASTContext &context, DeclContext *declContext) { BuiltinTypeDeclBuilder Builder(declContext, "BufferPointer", diff --git a/tools/clang/lib/SPIRV/AstTypeProbe.cpp b/tools/clang/lib/SPIRV/AstTypeProbe.cpp index fda9a3ab3e..48c3012501 100644 --- a/tools/clang/lib/SPIRV/AstTypeProbe.cpp +++ b/tools/clang/lib/SPIRV/AstTypeProbe.cpp @@ -926,6 +926,17 @@ bool isTexture(QualType type) { return false; } +bool isSampledTexture(QualType type) { + if (const auto *rt = type->getAs()) { + const auto name = rt->getDecl()->getName(); + // TODO(https://github.com/microsoft/DirectXShaderCompiler/issues/7979): Add + // other sampled texture types as needed. + if (name == "SampledTexture2D") + return true; + } + return false; +} + bool isTextureMS(QualType type) { if (const auto *rt = type->getAs()) { const auto name = rt->getDecl()->getName(); diff --git a/tools/clang/lib/SPIRV/LowerTypeVisitor.cpp b/tools/clang/lib/SPIRV/LowerTypeVisitor.cpp index b660ea70df..d9ddc3428b 100644 --- a/tools/clang/lib/SPIRV/LowerTypeVisitor.cpp +++ b/tools/clang/lib/SPIRV/LowerTypeVisitor.cpp @@ -849,6 +849,22 @@ const SpirvType *LowerTypeVisitor::lowerVkTypeInVkNamespace( assert(visitedTypeStack.size() == visitedTypeStackSize); return pointerType; } + if (name == "SampledTexture2D") { + const auto sampledType = hlsl::GetHLSLResourceResultType(type); + auto loweredType = lowerType(getElementType(astContext, sampledType), rule, + /*isRowMajor*/ llvm::None, srcLoc); + + // Treat bool textures as uint for compatibility with OpTypeImage. + if (loweredType == spvContext.getBoolType()) { + loweredType = spvContext.getUIntType(32); + } + + const auto *imageType = spvContext.getImageType( + loweredType, spv::Dim::Dim2D, ImageType::WithDepth::No, + false /* array */, false /* ms */, ImageType::WithSampler::Yes, + spv::ImageFormat::Unknown); + return spvContext.getSampledImageType(imageType); + } emitError("unknown type %0 in vk namespace", srcLoc) << type; return nullptr; } diff --git a/tools/clang/lib/SPIRV/SpirvBuilder.cpp b/tools/clang/lib/SPIRV/SpirvBuilder.cpp index 86701f48fd..319436749f 100644 --- a/tools/clang/lib/SPIRV/SpirvBuilder.cpp +++ b/tools/clang/lib/SPIRV/SpirvBuilder.cpp @@ -620,8 +620,15 @@ SpirvInstruction *SpirvBuilder::createImageSample( assert(lod == nullptr || minLod == nullptr); // An OpSampledImage is required to do the image sampling. - auto *sampledImage = - createSampledImage(imageType, image, sampler, loc, range); + // Skip creating OpSampledImage if the imageType is a sampled texture. + SpirvInstruction *sampledImage = nullptr; + if (isSampledTexture(imageType)) { + assert(!sampler && + "sampler must be null when sampling from a sampled texture"); + sampledImage = image; + } else { + sampledImage = createSampledImage(imageType, image, sampler, loc, range); + } const auto mask = composeImageOperandsMask( bias, lod, grad, constOffset, varOffset, constOffsets, sample, minLod); @@ -707,8 +714,15 @@ SpirvInstruction *SpirvBuilder::createImageGather( assert(insertPoint && "null insert point"); // An OpSampledImage is required to do the image sampling. - auto *sampledImage = - createSampledImage(imageType, image, sampler, loc, range); + // Skip creating OpSampledImage if the imageType is a sampled texture. + SpirvInstruction *sampledImage = nullptr; + if (isSampledTexture(imageType)) { + assert(!sampler && + "sampler must be null when sampling from a sampled texture"); + sampledImage = image; + } else { + sampledImage = createSampledImage(imageType, image, sampler, loc, range); + } // TODO: Update ImageGather to accept minLod if necessary. const auto mask = composeImageOperandsMask( diff --git a/tools/clang/lib/SPIRV/SpirvEmitter.cpp b/tools/clang/lib/SPIRV/SpirvEmitter.cpp index 2c8b8a3440..4b10291fbd 100644 --- a/tools/clang/lib/SPIRV/SpirvEmitter.cpp +++ b/tools/clang/lib/SPIRV/SpirvEmitter.cpp @@ -4451,15 +4451,25 @@ SpirvEmitter::processTextureLevelOfDetail(const CXXMemberCallExpr *expr, // Texture2D(Array).CalculateLevelOfDetail(SamplerState S, float2 xy); // TextureCube(Array).CalculateLevelOfDetail(SamplerState S, float3 xyz); // Texture3D.CalculateLevelOfDetail(SamplerState S, float3 xyz); + // SampledTexture2D.CalculateLevelOfDetail(float2 xy); // Return type is always a single float (LOD). - assert(expr->getNumArgs() == 2u); - const auto *object = expr->getImplicitObjectArgument(); - auto *objectInfo = loadIfGLValue(object); - auto *samplerState = doExpr(expr->getArg(0)); - auto *coordinate = doExpr(expr->getArg(1)); - auto *sampledImage = spvBuilder.createSampledImage( - object->getType(), objectInfo, samplerState, expr->getExprLoc()); + const auto *imageExpr = expr->getImplicitObjectArgument(); + const QualType imageType = imageExpr->getType(); + // numarg is 1 if isSampledTexture(imageType). otherwise 2. + assert(expr->getNumArgs() == (isSampledTexture(imageType) ? 1u : 2u)); + + auto *objectInfo = loadIfGLValue(imageExpr); + auto *samplerState = + isSampledTexture(imageType) ? nullptr : doExpr(expr->getArg(0)); + auto *coordinate = isSampledTexture(imageType) ? doExpr(expr->getArg(0)) + : doExpr(expr->getArg(1)); + + auto *sampledImage = + isSampledTexture(imageType) + ? objectInfo + : spvBuilder.createSampledImage(imageExpr->getType(), objectInfo, + samplerState, expr->getExprLoc()); // The result type of OpImageQueryLod must be a float2. const QualType queryResultType = @@ -5813,10 +5823,70 @@ SpirvEmitter::processTextureSampleGather(const CXXMemberCallExpr *expr, // [, uint Status]); // // Other Texture types do not have a Gather method. + // + // For SampledTexture2D: + // DXGI_FORMAT Object.Sample(float Location + // [, int Offset] + // [, float Clamp] + // [, out uint Status]); + // + // For SampledTexture2D: + //