diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVAttributes.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVAttributes.td index f2a12f68d481b..1bc3c63646fdd 100644 --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVAttributes.td +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVAttributes.td @@ -97,6 +97,20 @@ def SPIRV_CooperativeMatrixPropertiesNVAttr : let assemblyFormat = "`<` struct(params) `>`"; } +def SPIRV_CacheControlLoadINTELAttr : + SPIRV_Attr<"CacheControlLoadINTEL", "cache_control_load_intel"> { + let parameters = (ins "unsigned":$cache_level, + "mlir::spirv::LoadCacheControl":$load_cache_control); + let assemblyFormat = "`<` struct(params) `>`"; +} + +def SPIRV_CacheControlStoreINTELAttr : + SPIRV_Attr<"CacheControlStoreINTEL", "cache_control_store_intel"> { + let parameters = (ins "unsigned":$cache_level, + "mlir::spirv::StoreCacheControl":$store_cache_control); + let assemblyFormat = "`<` struct(params) `>`"; +} + def SPIRV_CooperativeMatrixPropertiesNVArrayAttr : TypedArrayAttrBase; diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td index 3b7da9b44a08f..252d9319fccc5 100644 --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td @@ -400,6 +400,7 @@ def SPV_INTEL_fp_fast_math_mode : I32EnumAttrCase<"SPV_INTEL_fp def SPV_INTEL_memory_access_aliasing : I32EnumAttrCase<"SPV_INTEL_memory_access_aliasing", 4028>; def SPV_INTEL_split_barrier : I32EnumAttrCase<"SPV_INTEL_split_barrier", 4029>; def SPV_INTEL_bfloat16_conversion : I32EnumAttrCase<"SPV_INTEL_bfloat16_conversion", 4031>; +def SPV_INTEL_cache_controls : I32EnumAttrCase<"SPV_INTEL_cache_controls", 4032>; def SPV_NV_compute_shader_derivatives : I32EnumAttrCase<"SPV_NV_compute_shader_derivatives", 5000>; def SPV_NV_cooperative_matrix : I32EnumAttrCase<"SPV_NV_cooperative_matrix", 5001>; @@ -459,7 +460,8 @@ def SPIRV_ExtensionAttr : SPV_INTEL_fpga_reg, SPV_INTEL_long_constant_composite, SPV_INTEL_optnone, SPV_INTEL_debug_module, SPV_INTEL_fp_fast_math_mode, SPV_INTEL_memory_access_aliasing, SPV_INTEL_split_barrier, - SPV_INTEL_bfloat16_conversion, SPV_NV_compute_shader_derivatives, SPV_NV_cooperative_matrix, + SPV_INTEL_bfloat16_conversion, SPV_INTEL_cache_controls, + SPV_NV_compute_shader_derivatives, SPV_NV_cooperative_matrix, SPV_NV_fragment_shader_barycentric, SPV_NV_geometry_shader_passthrough, SPV_NV_mesh_shader, SPV_NV_ray_tracing, SPV_NV_sample_mask_override_coverage, SPV_NV_shader_image_footprint, SPV_NV_shader_sm_builtins, @@ -1415,6 +1417,12 @@ def SPIRV_C_Bfloat16ConversionINTEL : I32EnumAttrCase<"B ]; } +def SPIRV_C_CacheControlsINTEL : I32EnumAttrCase<"CacheControlsINTEL", 6441> { + list availability = [ + Extension<[SPV_INTEL_cache_controls]> + ]; +} + def SPIRV_CapabilityAttr : SPIRV_I32EnumAttr<"Capability", "valid SPIR-V Capability", "capability", [ SPIRV_C_Matrix, SPIRV_C_Addresses, SPIRV_C_Linkage, SPIRV_C_Kernel, SPIRV_C_Float16, @@ -1507,7 +1515,8 @@ def SPIRV_CapabilityAttr : SPIRV_C_UniformTexelBufferArrayNonUniformIndexing, SPIRV_C_StorageTexelBufferArrayNonUniformIndexing, SPIRV_C_ShaderViewportIndexLayerEXT, SPIRV_C_ShaderViewportMaskNV, - SPIRV_C_ShaderStereoViewNV, SPIRV_C_Bfloat16ConversionINTEL + SPIRV_C_ShaderStereoViewNV, SPIRV_C_Bfloat16ConversionINTEL, + SPIRV_C_CacheControlsINTEL ]>; def SPIRV_AM_Logical : I32EnumAttrCase<"Logical", 0>; @@ -2623,6 +2632,16 @@ def SPIRV_D_MediaBlockIOINTEL : I32EnumAttrCase<"MediaBlockIOIN Capability<[SPIRV_C_VectorComputeINTEL]> ]; } +def SPIRV_D_CacheControlLoadINTEL : I32EnumAttrCase<"CacheControlLoadINTEL", 6442> { + list availability = [ + Capability<[SPIRV_C_CacheControlsINTEL]> + ]; +} +def SPIRV_D_CacheControlStoreINTEL : I32EnumAttrCase<"CacheControlStoreINTEL", 6443> { + list availability = [ + Capability<[SPIRV_C_CacheControlsINTEL]> + ]; +} def SPIRV_DecorationAttr : SPIRV_I32EnumAttr<"Decoration", "valid SPIR-V Decoration", "decoration", [ @@ -2658,7 +2677,8 @@ def SPIRV_DecorationAttr : SPIRV_D_FuseLoopsInFunctionINTEL, SPIRV_D_AliasScopeINTEL, SPIRV_D_NoAliasINTEL, SPIRV_D_BufferLocationINTEL, SPIRV_D_IOPipeStorageINTEL, SPIRV_D_FunctionFloatingPointModeINTEL, SPIRV_D_SingleElementVectorINTEL, - SPIRV_D_VectorComputeCallableFunctionINTEL, SPIRV_D_MediaBlockIOINTEL + SPIRV_D_VectorComputeCallableFunctionINTEL, SPIRV_D_MediaBlockIOINTEL, + SPIRV_D_CacheControlLoadINTEL, SPIRV_D_CacheControlStoreINTEL ]>; def SPIRV_D_1D : I32EnumAttrCase<"Dim1D", 0> { @@ -4092,6 +4112,32 @@ def SPIRV_KHR_CooperativeMatrixOperandsAttr : SPIRV_KHR_CMO_Result_Signed, SPIRV_KHR_CMO_AccSat ]>; +def SPIRV_INTEL_LCC_Uncached : I32EnumAttrCase<"Uncached", 0>; +def SPIRV_INTEL_LCC_Cached : I32EnumAttrCase<"Cached", 1>; +def SPIRV_INTEL_LCC_Streaming : I32EnumAttrCase<"Streaming", 2>; +def SPIRV_INTEL_LCC_InvalidateAfterRead : I32EnumAttrCase<"InvalidateAfterR", 3>; +def SPIRV_INTEL_LCC_ConstCached : I32EnumAttrCase<"ConstCached", 4>; + +def SPIRV_INTEL_LoadCacheControlAttr : + SPIRV_I32EnumAttr<"LoadCacheControl", "valid SPIR-V LoadCacheControl", + "load_cache_control", [ + SPIRV_INTEL_LCC_Uncached, SPIRV_INTEL_LCC_Cached, + SPIRV_INTEL_LCC_Streaming, SPIRV_INTEL_LCC_InvalidateAfterRead, + SPIRV_INTEL_LCC_ConstCached + ]>; + +def SPIRV_INTEL_SCC_Uncached : I32EnumAttrCase<"Uncached", 0>; +def SPIRV_INTEL_SCC_WriteThrough : I32EnumAttrCase<"WriteThrough", 1>; +def SPIRV_INTEL_SCC_WriteBack : I32EnumAttrCase<"WriteBack", 2>; +def SPIRV_INTEL_SCC_Streaming : I32EnumAttrCase<"Streaming", 3>; + +def SPIRV_INTEL_StoreCacheControlAttr : + SPIRV_I32EnumAttr<"StoreCacheControl", "valid SPIR-V StoreCacheControl", + "store_cache_control", [ + SPIRV_INTEL_SCC_Uncached, SPIRV_INTEL_SCC_WriteThrough, + SPIRV_INTEL_SCC_WriteBack, SPIRV_INTEL_SCC_Streaming + ]>; + //===----------------------------------------------------------------------===// // SPIR-V attribute definitions //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp index 462d3e326b6c2..04469f1933819 100644 --- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp +++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp @@ -226,6 +226,28 @@ spirv::Deserializer::processMemoryModel(ArrayRef operands) { return success(); } +template +LogicalResult deserializeCacheControlDecoration( + Location loc, OpBuilder &opBuilder, + DenseMap &decorations, ArrayRef words, + StringAttr symbol, StringRef decorationName, StringRef cacheControlKind) { + if (words.size() != 4) { + return emitError(loc, "OpDecoration with ") + << decorationName << "needs a cache control integer literal and a " + << cacheControlKind << " cache control literal"; + } + unsigned cacheLevel = words[2]; + auto cacheControlAttr = static_cast(words[3]); + auto value = opBuilder.getAttr(cacheLevel, cacheControlAttr); + SmallVector attrs; + if (auto attrList = + llvm::dyn_cast_or_null(decorations[words[0]].get(symbol))) + llvm::append_range(attrs, attrList); + attrs.push_back(value); + decorations[words[0]].set(symbol, opBuilder.getArrayAttr(attrs)); + return success(); +} + LogicalResult spirv::Deserializer::processDecoration(ArrayRef words) { // TODO: This function should also be auto-generated. For now, since only a // few decorations are processed/handled in a meaningful manner, going with a @@ -339,6 +361,24 @@ LogicalResult spirv::Deserializer::processDecoration(ArrayRef words) { decorations[words[0]].set( symbol, opBuilder.getI32IntegerAttr(static_cast(words[2]))); break; + case spirv::Decoration::CacheControlLoadINTEL: { + LogicalResult res = deserializeCacheControlDecoration< + CacheControlLoadINTELAttr, LoadCacheControlAttr, LoadCacheControl>( + unknownLoc, opBuilder, decorations, words, symbol, decorationName, + "load"); + if (failed(res)) + return res; + break; + } + case spirv::Decoration::CacheControlStoreINTEL: { + LogicalResult res = deserializeCacheControlDecoration< + CacheControlStoreINTELAttr, StoreCacheControlAttr, StoreCacheControl>( + unknownLoc, opBuilder, decorations, words, symbol, decorationName, + "store"); + if (failed(res)) + return res; + break; + } default: return emitError(unknownLoc, "unhandled Decoration : '") << decorationName; } diff --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp index f355982e9ed88..1f4f5d7f764db 100644 --- a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp +++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp @@ -217,10 +217,42 @@ static std::string getDecorationName(StringRef attrName) { // similar here if (attrName == "fp_rounding_mode") return "FPRoundingMode"; + // convertToCamelFromSnakeCase will not capitalize "INTEL". + if (attrName == "cache_control_load_intel") + return "CacheControlLoadINTEL"; + if (attrName == "cache_control_store_intel") + return "CacheControlStoreINTEL"; return llvm::convertToCamelFromSnakeCase(attrName, /*capitalizeFirst=*/true); } +template +LogicalResult processDecorationList(Location loc, Decoration decoration, + Attribute attrList, StringRef attrName, + EmitF emitter) { + auto arrayAttr = dyn_cast(attrList); + if (!arrayAttr) { + return emitError(loc, "expecting array attribute of ") + << attrName << " for " << stringifyDecoration(decoration); + } + if (arrayAttr.empty()) { + return emitError(loc, "expecting non-empty array attribute of ") + << attrName << " for " << stringifyDecoration(decoration); + } + for (Attribute attr : arrayAttr.getValue()) { + auto cacheControlAttr = dyn_cast(attr); + if (!cacheControlAttr) { + return emitError(loc, "expecting array attribute of ") + << attrName << " for " << stringifyDecoration(decoration); + } + // This named attribute encodes several decorations. Emit one per + // element in the array. + if (failed(emitter(cacheControlAttr))) + return failure(); + } + return success(); +} + LogicalResult Serializer::processDecorationAttr(Location loc, uint32_t resultID, Decoration decoration, Attribute attr) { @@ -294,6 +326,26 @@ LogicalResult Serializer::processDecorationAttr(Location loc, uint32_t resultID, return emitError(loc, "expected unit attribute or decoration attribute for ") << stringifyDecoration(decoration); + case spirv::Decoration::CacheControlLoadINTEL: + return processDecorationList( + loc, decoration, attr, "CacheControlLoadINTEL", + [&](CacheControlLoadINTELAttr attr) { + unsigned cacheLevel = attr.getCacheLevel(); + LoadCacheControl loadCacheControl = attr.getLoadCacheControl(); + return emitDecoration( + resultID, decoration, + {cacheLevel, static_cast(loadCacheControl)}); + }); + case spirv::Decoration::CacheControlStoreINTEL: + return processDecorationList( + loc, decoration, attr, "CacheControlStoreINTEL", + [&](CacheControlStoreINTELAttr attr) { + unsigned cacheLevel = attr.getCacheLevel(); + StoreCacheControl storeCacheControl = attr.getStoreCacheControl(); + return emitDecoration( + resultID, decoration, + {cacheLevel, static_cast(storeCacheControl)}); + }); default: return emitError(loc, "unhandled decoration ") << stringifyDecoration(decoration); diff --git a/mlir/test/Dialect/SPIRV/IR/intel-ext-ops.mlir b/mlir/test/Dialect/SPIRV/IR/intel-ext-ops.mlir index 53a1015de75bc..66c70e816d413 100644 --- a/mlir/test/Dialect/SPIRV/IR/intel-ext-ops.mlir +++ b/mlir/test/Dialect/SPIRV/IR/intel-ext-ops.mlir @@ -69,3 +69,21 @@ spirv.func @bf16_to_f32_vec_unsupported(%arg0 : vector<2xi16>) "None" { %0 = spirv.INTEL.ConvertBF16ToF %arg0 : vector<2xi16> to vector<3xf32> spirv.Return } + +// ----- + +//===----------------------------------------------------------------------===// +// spirv.INTEL.CacheControls +//===----------------------------------------------------------------------===// + +spirv.module Logical GLSL450 requires #spirv.vce { + spirv.func @foo() "None" { + // CHECK: spirv.Variable {cache_control_load_intel = [#spirv.cache_control_load_intel, #spirv.cache_control_load_intel, #spirv.cache_control_load_intel]} : !spirv.ptr + %0 = spirv.Variable {cache_control_load_intel = [#spirv.cache_control_load_intel, #spirv.cache_control_load_intel, #spirv.cache_control_load_intel]} : !spirv.ptr + // CHECK: spirv.Variable {cache_control_store_intel = [#spirv.cache_control_store_intel, #spirv.cache_control_store_intel, #spirv.cache_control_store_intel]} : !spirv.ptr + %1 = spirv.Variable {cache_control_store_intel = [#spirv.cache_control_store_intel, #spirv.cache_control_store_intel, #spirv.cache_control_store_intel]} : !spirv.ptr + spirv.Return + } +} + +// ----- diff --git a/mlir/test/Target/SPIRV/decorations.mlir b/mlir/test/Target/SPIRV/decorations.mlir index 0a29290b6a6fa..d66ac74dc4ef9 100644 --- a/mlir/test/Target/SPIRV/decorations.mlir +++ b/mlir/test/Target/SPIRV/decorations.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-translate -no-implicit-module -split-input-file -test-spirv-roundtrip %s | FileCheck %s +// RUN: mlir-translate -no-implicit-module -split-input-file -test-spirv-roundtrip -verify-diagnostics %s | FileCheck %s spirv.module Logical GLSL450 requires #spirv.vce { // CHECK: location = 0 : i32 @@ -107,3 +107,47 @@ spirv.func @fp_rounding_mode(%arg: f32) -> f16 "None" { spirv.ReturnValue %0 : f16 } } + +// ----- + +// CHECK-LABEL: spirv.module Logical GLSL450 requires #spirv.vce { + +spirv.module Logical GLSL450 requires #spirv.vce { + spirv.func @cache_controls() "None" { + // CHECK: spirv.Variable {cache_control_load_intel = [#spirv.cache_control_load_intel, #spirv.cache_control_load_intel, #spirv.cache_control_load_intel]} : !spirv.ptr + %0 = spirv.Variable {cache_control_load_intel = [#spirv.cache_control_load_intel, #spirv.cache_control_load_intel, #spirv.cache_control_load_intel]} : !spirv.ptr + // CHECK: spirv.Variable {cache_control_store_intel = [#spirv.cache_control_store_intel, #spirv.cache_control_store_intel, #spirv.cache_control_store_intel]} : !spirv.ptr + %1 = spirv.Variable {cache_control_store_intel = [#spirv.cache_control_store_intel, #spirv.cache_control_store_intel, #spirv.cache_control_store_intel]} : !spirv.ptr + spirv.Return + } +} + +// ----- + +spirv.module Logical GLSL450 requires #spirv.vce { + spirv.func @cache_controls_invalid_type() "None" { + // expected-error@below {{expecting array attribute of CacheControlLoadINTEL for CacheControlLoadINTEL}} + %0 = spirv.Variable {cache_control_load_intel = #spirv.cache_control_load_intel} : !spirv.ptr + spirv.Return + } +} + +// ----- + +spirv.module Logical GLSL450 requires #spirv.vce { + spirv.func @cache_controls_invalid_type() "None" { + // expected-error@below {{expecting array attribute of CacheControlStoreINTEL for CacheControlStoreINTEL}} + %0 = spirv.Variable {cache_control_store_intel = [#spirv.cache_control_store_intel, 0 : i32]} : !spirv.ptr + spirv.Return + } +} + +// ----- + +spirv.module Logical GLSL450 requires #spirv.vce { + spirv.func @cache_controls_invalid_type() "None" { + // expected-error@below {{expecting non-empty array attribute of CacheControlStoreINTEL for CacheControlStoreINTEL}} + %0 = spirv.Variable {cache_control_store_intel = []} : !spirv.ptr + spirv.Return + } +}