From 9d804237831db8ea903119efb68c384b9c1d90ea Mon Sep 17 00:00:00 2001 From: Sven van Haastregt Date: Mon, 20 Oct 2025 10:14:02 +0100 Subject: [PATCH] Always emit coopmat conversions as SPIR-V friendly IR calls Map all cooperative matrix type conversions to SPIR-V friendly IR calls, regardless of the environment specified. In particular, do not attempt to map such conversions to the OpenCL `convert` builtin. The SPIR-V TargetExtType is already encoded in the function suffix, so the previous translation was an odd hybrid between OpenCL and SPIR-V friendly IR. --- lib/SPIRV/SPIRVToOCL.cpp | 7 +++++++ .../SPV_KHR_cooperative_matrix/conversion_instructions.ll | 5 +++++ 2 files changed, 12 insertions(+) diff --git a/lib/SPIRV/SPIRVToOCL.cpp b/lib/SPIRV/SPIRVToOCL.cpp index 7b8003b851..e2e36bc8b7 100644 --- a/lib/SPIRV/SPIRVToOCL.cpp +++ b/lib/SPIRV/SPIRVToOCL.cpp @@ -676,6 +676,13 @@ void SPIRVToOCLBase::visitCallGenericCastToPtrExplicitBuiltIn(CallInst *CI, void SPIRVToOCLBase::visitCallSPIRVCvtBuiltin(CallInst *CI, Op OC, StringRef DemangledName) { + if (auto *TET = + dyn_cast(CI->getFunctionType()->getReturnType())) { + // Preserve any cooperative matrix type conversions as SPIR-V calls. + if (TET->getName() == "spirv.CooperativeMatrixKHR") { + return; + } + } std::string CastBuiltInName; if (isCvtFromUnsignedOpCode(OC)) CastBuiltInName = "u"; diff --git a/test/extensions/KHR/SPV_KHR_cooperative_matrix/conversion_instructions.ll b/test/extensions/KHR/SPV_KHR_cooperative_matrix/conversion_instructions.ll index d92b377c98..8e852a405d 100644 --- a/test/extensions/KHR/SPV_KHR_cooperative_matrix/conversion_instructions.ll +++ b/test/extensions/KHR/SPV_KHR_cooperative_matrix/conversion_instructions.ll @@ -17,6 +17,11 @@ ; RUN: llvm-dis %t.rev.bc ; RUN: FileCheck < %t.rev.ll %s --check-prefix=CHECK-LLVM +; Ensure cooperative matrix conversions are mapped to SPIR-V friendly IR calls. +; RUN: llvm-spirv -r --spirv-target-env=CL2.0 %t.spv -o %t.rev.bc +; RUN: llvm-dis %t.rev.bc +; RUN: FileCheck < %t.rev.ll %s --check-prefix=CHECK-LLVM + ; CHECK-SPIRV: TypeInt [[#TypeInt32:]] 32 0 ; CHECK-SPIRV: TypeInt [[#TypeInt16:]] 16 0 ; CHECK-SPIRV: TypeInt [[#TypeInt8:]] 8 0