Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion include/dxc/dxcapi.internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,8 @@ enum LEGAL_INTRINSIC_COMPTYPES {

#ifdef ENABLE_SPIRV_CODEGEN
LICOMPTYPE_VK_BUFFER_POINTER = 54,
LICOMPTYPE_COUNT = 55
LICOMPTYPE_VK_SAMPLED_TEXTURE2D = 55,
LICOMPTYPE_COUNT = 56
#else
LICOMPTYPE_COUNT = 54
#endif
Expand Down
5 changes: 5 additions & 0 deletions tools/clang/include/clang/AST/HlslTypes.h
Original file line number Diff line number Diff line change
Expand Up @@ -407,6 +407,11 @@ clang::CXXRecordDecl *
DeclareVkBufferPointerType(clang::ASTContext &context,
clang::DeclContext *declContext);

clang::CXXRecordDecl *DeclareVkSampledTexture2DType(
clang::ASTContext &context, clang::DeclContext *declContext,
clang::QualType float2Type, clang::QualType int2Type,
clang::QualType float4Type);

clang::CXXRecordDecl *DeclareInlineSpirvType(clang::ASTContext &context,
clang::DeclContext *declContext,
llvm::StringRef typeName,
Expand Down
3 changes: 3 additions & 0 deletions tools/clang/include/clang/SPIRV/AstTypeProbe.h
Original file line number Diff line number Diff line change
Expand Up @@ -257,6 +257,9 @@ bool isTexture(QualType);
/// Texture2DMSArray type.
bool isTextureMS(QualType);

/// \brief Returns true if the given type is an HLSL SampledTexture type.
bool isSampledTexture(QualType);

/// \brief Returns true if the given type is an HLSL RWTexture type.
bool isRWTexture(QualType);

Expand Down
5 changes: 5 additions & 0 deletions tools/clang/include/clang/SPIRV/SpirvBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -286,6 +286,9 @@ class SpirvBuilder {
/// If compareVal is given a non-zero value, *Dref* variants of OpImageSample*
/// will be generated.
///
/// If imageType is not a sampled image type, the OpSampledImage* instructions
/// will be generated.
///
/// If lod or grad is given a non-zero value, *ExplicitLod variants of
/// OpImageSample* will be generated; otherwise, *ImplicitLod variant will
/// be generated.
Expand Down Expand Up @@ -334,6 +337,8 @@ class SpirvBuilder {

/// \brief Creates SPIR-V instructions for gathering the given image.
///
/// If imageType is not a sampled image type, the OpSampledImage* instructions
/// will be generated.
/// If compareVal is given a non-null value, OpImageDrefGather or
/// OpImageSparseDrefGather will be generated; otherwise, OpImageGather or
/// OpImageSparseGather will be generated.
Expand Down
221 changes: 221 additions & 0 deletions tools/clang/lib/AST/ASTContextHLSL.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1369,6 +1369,227 @@ CXXRecordDecl *hlsl::DeclareNodeOrRecordType(
}

#ifdef ENABLE_SPIRV_CODEGEN
CXXRecordDecl *hlsl::DeclareVkSampledTexture2DType(ASTContext &context,
DeclContext *declContext,
QualType float2Type,
QualType int2Type,
QualType float4Type) {
BuiltinTypeDeclBuilder Builder(declContext, "SampledTexture2D",
TagDecl::TagKind::TTK_Struct);

QualType defaultTextureType = float4Type;
TemplateTypeParmDecl *TyParamDecl =
Builder.addTypeTemplateParam("SampledTextureType", defaultTextureType);

Builder.startDefinition();

QualType paramType = QualType(TyParamDecl->getTypeForDecl(), 0);
CXXRecordDecl *recordDecl = Builder.getRecordDecl();

QualType floatType = context.FloatTy;
QualType uintType = context.UnsignedIntTy;
QualType intType = context.IntTy;

// Add Sample method
// Sample(location)
CXXMethodDecl *sampleDecl = CreateObjectFunctionDeclarationWithParams(
context, recordDecl, paramType, ArrayRef<QualType>(float2Type),
ArrayRef<StringRef>(StringRef("location")),
context.DeclarationNames.getIdentifier(&context.Idents.get("Sample")),
/*isConst*/ true);
sampleDecl->addAttr(HLSLIntrinsicAttr::CreateImplicit(
context, "op", "", static_cast<int>(hlsl::IntrinsicOp::MOP_Sample)));
sampleDecl->addAttr(HLSLCXXOverloadAttr::CreateImplicit(context));

// Sample(location, offset)
QualType params2[] = {float2Type, int2Type};
StringRef names2[] = {"location", "offset"};
CXXMethodDecl *sampleDecl2 = CreateObjectFunctionDeclarationWithParams(
context, recordDecl, paramType, params2, names2,
context.DeclarationNames.getIdentifier(&context.Idents.get("Sample")),
/*isConst*/ true);
sampleDecl2->addAttr(HLSLIntrinsicAttr::CreateImplicit(
context, "op", "", static_cast<int>(hlsl::IntrinsicOp::MOP_Sample)));
sampleDecl2->addAttr(HLSLCXXOverloadAttr::CreateImplicit(context));

// Sample(location, offset, clamp)
QualType params3[] = {float2Type, int2Type, floatType};
StringRef names3[] = {"location", "offset", "clamp"};
CXXMethodDecl *sampleDecl3 = CreateObjectFunctionDeclarationWithParams(
context, recordDecl, paramType, params3, names3,
context.DeclarationNames.getIdentifier(&context.Idents.get("Sample")),
/*isConst*/ true);
sampleDecl3->addAttr(HLSLIntrinsicAttr::CreateImplicit(
context, "op", "", static_cast<int>(hlsl::IntrinsicOp::MOP_Sample)));
sampleDecl3->addAttr(HLSLCXXOverloadAttr::CreateImplicit(context));

// Sample(location, offset, clamp, status)
QualType params4[] = {float2Type, int2Type, floatType,
context.getLValueReferenceType(uintType)};
StringRef names4[] = {"location", "offset", "clamp", "status"};
CXXMethodDecl *sampleDecl4 = CreateObjectFunctionDeclarationWithParams(
context, recordDecl, paramType, params4, names4,
context.DeclarationNames.getIdentifier(&context.Idents.get("Sample")),
/*isConst*/ true);
sampleDecl4->addAttr(HLSLIntrinsicAttr::CreateImplicit(
context, "op", "", static_cast<int>(hlsl::IntrinsicOp::MOP_Sample)));
sampleDecl4->addAttr(HLSLCXXOverloadAttr::CreateImplicit(context));

// CalculateLevelOfDetail(location)
CXXMethodDecl *lodDecl = CreateObjectFunctionDeclarationWithParams(
context, recordDecl, floatType, ArrayRef<QualType>(float2Type),
ArrayRef<StringRef>(StringRef("location")),
context.DeclarationNames.getIdentifier(
&context.Idents.get("CalculateLevelOfDetail")),
/*isConst*/ true);
lodDecl->addAttr(HLSLIntrinsicAttr::CreateImplicit(
context, "op", "",
static_cast<int>(hlsl::IntrinsicOp::MOP_CalculateLevelOfDetail)));
lodDecl->addAttr(HLSLCXXOverloadAttr::CreateImplicit(context));

// CalculateLevelOfDetailUnclamped(location)
CXXMethodDecl *lodUnclampedDecl = CreateObjectFunctionDeclarationWithParams(
context, recordDecl, floatType, ArrayRef<QualType>(float2Type),
ArrayRef<StringRef>(StringRef("location")),
context.DeclarationNames.getIdentifier(
&context.Idents.get("CalculateLevelOfDetailUnclamped")),
/*isConst*/ true);
lodUnclampedDecl->addAttr(HLSLIntrinsicAttr::CreateImplicit(
context, "op", "",
static_cast<int>(
hlsl::IntrinsicOp::MOP_CalculateLevelOfDetailUnclamped)));
lodUnclampedDecl->addAttr(HLSLCXXOverloadAttr::CreateImplicit(context));

// Gather(location)
CXXMethodDecl *gatherDecl = CreateObjectFunctionDeclarationWithParams(
context, recordDecl, paramType, ArrayRef<QualType>(float2Type),
ArrayRef<StringRef>(StringRef("location")),
context.DeclarationNames.getIdentifier(&context.Idents.get("Gather")),
/*isConst*/ true);
gatherDecl->addAttr(HLSLIntrinsicAttr::CreateImplicit(
context, "op", "", static_cast<int>(hlsl::IntrinsicOp::MOP_Gather)));
gatherDecl->addAttr(HLSLCXXOverloadAttr::CreateImplicit(context));
// Gather(location, offset)
QualType gatherParams2[] = {float2Type, int2Type};
StringRef gatherNames2[] = {"location", "offset"};
CXXMethodDecl *gatherDecl2 = CreateObjectFunctionDeclarationWithParams(
context, recordDecl, paramType, gatherParams2, gatherNames2,
context.DeclarationNames.getIdentifier(&context.Idents.get("Gather")),
/*isConst*/ true);
gatherDecl2->addAttr(HLSLIntrinsicAttr::CreateImplicit(
context, "op", "", static_cast<int>(hlsl::IntrinsicOp::MOP_Gather)));
gatherDecl2->addAttr(HLSLCXXOverloadAttr::CreateImplicit(context));
// Gather(location, offset, status)
QualType gatherParams3[] = {float2Type, int2Type,
context.getLValueReferenceType(uintType)};
StringRef gatherNames3[] = {"location", "offset", "status"};
CXXMethodDecl *gatherDecl3 = CreateObjectFunctionDeclarationWithParams(
context, recordDecl, paramType, gatherParams3, gatherNames3,
context.DeclarationNames.getIdentifier(&context.Idents.get("Gather")),
/*isConst*/ true);
gatherDecl3->addAttr(HLSLIntrinsicAttr::CreateImplicit(
context, "op", "", static_cast<int>(hlsl::IntrinsicOp::MOP_Gather)));
gatherDecl3->addAttr(HLSLCXXOverloadAttr::CreateImplicit(context));

// GetDimensions(width, height)
QualType getDimensionsParams2[] = {context.getLValueReferenceType(uintType),
context.getLValueReferenceType(uintType)};
StringRef getDimensionsNames2[] = {"width", "height"};
CXXMethodDecl *getDimensionsDecl2 = CreateObjectFunctionDeclarationWithParams(
context, recordDecl, context.VoidTy, getDimensionsParams2,
getDimensionsNames2,
context.DeclarationNames.getIdentifier(
&context.Idents.get("GetDimensions")),
/*isConst*/ true);
getDimensionsDecl2->addAttr(HLSLIntrinsicAttr::CreateImplicit(
context, "op", "",
static_cast<int>(hlsl::IntrinsicOp::MOP_GetDimensions)));
// GetDimensions(width, height) float version
QualType getDimensionsParams2Float[] = {
context.getLValueReferenceType(floatType),
context.getLValueReferenceType(floatType)};
StringRef getDimensionsNames2Float[] = {"width", "height"};
CXXMethodDecl *getDimensionsDecl2Float =
CreateObjectFunctionDeclarationWithParams(
context, recordDecl, context.VoidTy, getDimensionsParams2Float,
getDimensionsNames2Float,
context.DeclarationNames.getIdentifier(
&context.Idents.get("GetDimensions")),
/*isConst*/ true);
getDimensionsDecl2Float->addAttr(HLSLIntrinsicAttr::CreateImplicit(
context, "op", "",
static_cast<int>(hlsl::IntrinsicOp::MOP_GetDimensions)));
// GetDimensions(width, height) int version
QualType getDimensionsParams2Int[] = {
context.getLValueReferenceType(intType),
context.getLValueReferenceType(intType)};
StringRef getDimensionsNames2Int[] = {"width", "height"};
CXXMethodDecl *getDimensionsDecl2Int =
CreateObjectFunctionDeclarationWithParams(
context, recordDecl, context.VoidTy, getDimensionsParams2Int,
getDimensionsNames2Int,
context.DeclarationNames.getIdentifier(
&context.Idents.get("GetDimensions")),
/*isConst*/ true);
getDimensionsDecl2Int->addAttr(HLSLIntrinsicAttr::CreateImplicit(
context, "op", "",
static_cast<int>(hlsl::IntrinsicOp::MOP_GetDimensions)));

// GetDimensions(mipLevel, width, height, numLevels)
QualType getDimensionsParams4[] = {uintType,
context.getLValueReferenceType(uintType),
context.getLValueReferenceType(uintType),
context.getLValueReferenceType(uintType)};
StringRef getDimensionsNames4[] = {"mipLevel", "width", "height",
"numLevels"};
CXXMethodDecl *getDimensionsDecl4 = CreateObjectFunctionDeclarationWithParams(
context, recordDecl, context.VoidTy, getDimensionsParams4,
getDimensionsNames4,
context.DeclarationNames.getIdentifier(
&context.Idents.get("GetDimensions")),
/*isConst*/ true);
getDimensionsDecl4->addAttr(HLSLIntrinsicAttr::CreateImplicit(
context, "op", "",
static_cast<int>(hlsl::IntrinsicOp::MOP_GetDimensions)));
// GetDimensions(mipLevel, width, height, numLevels) float version
QualType getDimensionsParams4Float[] = {
uintType, context.getLValueReferenceType(floatType),
context.getLValueReferenceType(floatType),
context.getLValueReferenceType(floatType)};
StringRef getDimensionsNames4Float[] = {"mipLevel", "width", "height",
"numLevels"};
CXXMethodDecl *getDimensionsDecl4Float =
CreateObjectFunctionDeclarationWithParams(
context, recordDecl, context.VoidTy, getDimensionsParams4Float,
getDimensionsNames4Float,
context.DeclarationNames.getIdentifier(
&context.Idents.get("GetDimensions")),
/*isConst*/ true);
getDimensionsDecl4Float->addAttr(HLSLIntrinsicAttr::CreateImplicit(
context, "op", "",
static_cast<int>(hlsl::IntrinsicOp::MOP_GetDimensions)));
// GetDimensions(mipLevel, width, height, numLevels) int version
QualType getDimensionsParams4Int[] = {
uintType, context.getLValueReferenceType(intType),
context.getLValueReferenceType(intType),
context.getLValueReferenceType(intType)};
StringRef getDimensionsNames4Int[] = {"mipLevel", "width", "height",
"numLevels"};
CXXMethodDecl *getDimensionsDecl4Int =
CreateObjectFunctionDeclarationWithParams(
context, recordDecl, context.VoidTy, getDimensionsParams4Int,
getDimensionsNames4Int,
context.DeclarationNames.getIdentifier(
&context.Idents.get("GetDimensions")),
/*isConst*/ true);
getDimensionsDecl4Int->addAttr(HLSLIntrinsicAttr::CreateImplicit(
context, "op", "",
static_cast<int>(hlsl::IntrinsicOp::MOP_GetDimensions)));

Builder.completeDefinition();
return recordDecl;
}

CXXRecordDecl *hlsl::DeclareVkBufferPointerType(ASTContext &context,
DeclContext *declContext) {
BuiltinTypeDeclBuilder Builder(declContext, "BufferPointer",
Expand Down
11 changes: 11 additions & 0 deletions tools/clang/lib/SPIRV/AstTypeProbe.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -926,6 +926,17 @@ bool isTexture(QualType type) {
return false;
}

bool isSampledTexture(QualType type) {
if (const auto *rt = type->getAs<RecordType>()) {
const auto name = rt->getDecl()->getName();
// TODO(https://github.com/microsoft/DirectXShaderCompiler/issues/7979): Add
// other sampled texture types as needed.
if (name == "SampledTexture2D")
return true;
}
return false;
}

bool isTextureMS(QualType type) {
if (const auto *rt = type->getAs<RecordType>()) {
const auto name = rt->getDecl()->getName();
Expand Down
16 changes: 16 additions & 0 deletions tools/clang/lib/SPIRV/LowerTypeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -849,6 +849,22 @@ const SpirvType *LowerTypeVisitor::lowerVkTypeInVkNamespace(
assert(visitedTypeStack.size() == visitedTypeStackSize);
return pointerType;
}
if (name == "SampledTexture2D") {
const auto sampledType = hlsl::GetHLSLResourceResultType(type);
auto loweredType = lowerType(getElementType(astContext, sampledType), rule,
/*isRowMajor*/ llvm::None, srcLoc);

// Treat bool textures as uint for compatibility with OpTypeImage.
if (loweredType == spvContext.getBoolType()) {
loweredType = spvContext.getUIntType(32);
}

const auto *imageType = spvContext.getImageType(
loweredType, spv::Dim::Dim2D, ImageType::WithDepth::No,
false /* array */, false /* ms */, ImageType::WithSampler::Yes,
spv::ImageFormat::Unknown);
return spvContext.getSampledImageType(imageType);
}
emitError("unknown type %0 in vk namespace", srcLoc) << type;
return nullptr;
}
Expand Down
22 changes: 18 additions & 4 deletions tools/clang/lib/SPIRV/SpirvBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -620,8 +620,15 @@ SpirvInstruction *SpirvBuilder::createImageSample(
assert(lod == nullptr || minLod == nullptr);

// An OpSampledImage is required to do the image sampling.
auto *sampledImage =
createSampledImage(imageType, image, sampler, loc, range);
// Skip creating OpSampledImage if the imageType is a sampled texture.
SpirvInstruction *sampledImage = nullptr;
if (isSampledTexture(imageType)) {
assert(!sampler &&
"sampler must be null when sampling from a sampled texture");
sampledImage = image;
} else {
sampledImage = createSampledImage(imageType, image, sampler, loc, range);
}

const auto mask = composeImageOperandsMask(
bias, lod, grad, constOffset, varOffset, constOffsets, sample, minLod);
Expand Down Expand Up @@ -707,8 +714,15 @@ SpirvInstruction *SpirvBuilder::createImageGather(
assert(insertPoint && "null insert point");

// An OpSampledImage is required to do the image sampling.
auto *sampledImage =
createSampledImage(imageType, image, sampler, loc, range);
// Skip creating OpSampledImage if the imageType is a sampled texture.
SpirvInstruction *sampledImage = nullptr;
if (isSampledTexture(imageType)) {
assert(!sampler &&
"sampler must be null when sampling from a sampled texture");
sampledImage = image;
} else {
sampledImage = createSampledImage(imageType, image, sampler, loc, range);
}

// TODO: Update ImageGather to accept minLod if necessary.
const auto mask = composeImageOperandsMask(
Expand Down
Loading