From 97630ed79a77020fdb0ce3eb5a451b9d9188b39a Mon Sep 17 00:00:00 2001 From: "Zhao, Maosu" Date: Tue, 21 Oct 2025 08:50:13 +0200 Subject: [PATCH] [DevSAN] Skip instrumentation if module has esimd kernel for MSAN/TSAN Device sanitizer doesn't support esimd kernel now. --- .../SPIRVSanitizerCommonUtils.h | 3 + .../Instrumentation/AddressSanitizer.cpp | 8 +-- .../Instrumentation/MemorySanitizer.cpp | 56 +++++++++++-------- .../SPIRVSanitizerCommonUtils.cpp | 7 +++ .../Instrumentation/ThreadSanitizer.cpp | 33 +++++++---- .../MemorySanitizer/SPIRV/sycl_esimd.ll | 34 +++++++++++ .../ThreadSanitizer/SPIRV/sycl_esimd.ll | 32 +++++++++++ 7 files changed, 130 insertions(+), 43 deletions(-) create mode 100644 llvm/test/Instrumentation/MemorySanitizer/SPIRV/sycl_esimd.ll create mode 100644 llvm/test/Instrumentation/ThreadSanitizer/SPIRV/sycl_esimd.ll diff --git a/llvm/include/llvm/Transforms/Instrumentation/SPIRVSanitizerCommonUtils.h b/llvm/include/llvm/Transforms/Instrumentation/SPIRVSanitizerCommonUtils.h index ceea4b5cd5179..108ccb18b3925 100644 --- a/llvm/include/llvm/Transforms/Instrumentation/SPIRVSanitizerCommonUtils.h +++ b/llvm/include/llvm/Transforms/Instrumentation/SPIRVSanitizerCommonUtils.h @@ -15,6 +15,7 @@ #include "llvm/ADT/SmallString.h" #include "llvm/IR/Constants.h" #include "llvm/IR/DerivedTypes.h" +#include "llvm/IR/Module.h" #include "llvm/IR/Type.h" #include "llvm/IR/Value.h" @@ -42,6 +43,8 @@ SmallString<128> computeKernelMetadataUniqueId(StringRef Prefix, SmallVectorImpl &KernelNamesBytes); +bool hasESIMDKernel(Module &M); + // Sync with sanitizer_common/sanitizer_common.hpp enum SanitizedKernelFlags : uint32_t { NO_CHECK = 0, diff --git a/llvm/lib/Transforms/Instrumentation/AddressSanitizer.cpp b/llvm/lib/Transforms/Instrumentation/AddressSanitizer.cpp index 2dad2cf19ee5b..8915217731b18 100644 --- a/llvm/lib/Transforms/Instrumentation/AddressSanitizer.cpp +++ b/llvm/lib/Transforms/Instrumentation/AddressSanitizer.cpp @@ -1624,13 +1624,7 @@ PreservedAnalyses AddressSanitizerPass::run(Module &M, AsanSpirv->initializeCallbacks(); // FIXME: W/A skip instrumentation if this module has ESIMD - bool HasESIMD = false; - for (auto &F : M) { - if (F.hasMetadata("sycl_explicit_simd")) { - HasESIMD = true; - break; - } - } + bool HasESIMD = hasESIMDKernel(M); // Make sure "__AsanKernelMetadata" always exists ExtendSpirKernelArgs(M, FAM, HasESIMD); diff --git a/llvm/lib/Transforms/Instrumentation/MemorySanitizer.cpp b/llvm/lib/Transforms/Instrumentation/MemorySanitizer.cpp index 9654d73186638..2ed593a63a169 100644 --- a/llvm/lib/Transforms/Instrumentation/MemorySanitizer.cpp +++ b/llvm/lib/Transforms/Instrumentation/MemorySanitizer.cpp @@ -800,6 +800,7 @@ class MemorySanitizerOnSpirv { : M(M), C(M.getContext()), DL(M.getDataLayout()) { const auto &TargetTriple = Triple(M.getTargetTriple()); IsSPIRV = TargetTriple.isSPIROrSPIRV(); + HasESIMD = hasESIMDKernel(M); IntptrTy = DL.getIntPtrType(C); Int32Ty = Type::getInt32Ty(C); @@ -812,6 +813,8 @@ class MemorySanitizerOnSpirv { Constant *getOrCreateGlobalString(StringRef Name, StringRef Value, unsigned AddressSpace); + bool hasESIMD() { return HasESIMD; } + static bool isSupportedBuiltIn(StringRef Name); operator bool() const { return IsSPIRV; } @@ -834,6 +837,7 @@ class MemorySanitizerOnSpirv { LLVMContext &C; const DataLayout &DL; bool IsSPIRV; + bool HasESIMD; Type *IntptrTy; Type *Int32Ty; @@ -1242,34 +1246,35 @@ void MemorySanitizerOnSpirv::instrumentKernelsMetadata(int TrackOrigins) { // uptr unmangled_kernel_name_size // uptr sanitized_flags StructType *StructTy = StructType::get(IntptrTy, IntptrTy, IntptrTy); - for (Function &F : M) { - if (F.getCallingConv() != CallingConv::SPIR_KERNEL) - continue; + if (!HasESIMD) + for (Function &F : M) { + if (F.getCallingConv() != CallingConv::SPIR_KERNEL) + continue; - if (!F.hasFnAttribute(Attribute::SanitizeMemory) || - F.hasFnAttribute(Attribute::DisableSanitizerInstrumentation)) - continue; + if (!F.hasFnAttribute(Attribute::SanitizeMemory) || + F.hasFnAttribute(Attribute::DisableSanitizerInstrumentation)) + continue; - auto KernelName = F.getName(); - KernelNamesBytes.append(KernelName.begin(), KernelName.end()); - auto *KernelNameGV = getOrCreateGlobalString("__msan_kernel", KernelName, - kSpirOffloadConstantAS); + auto KernelName = F.getName(); + KernelNamesBytes.append(KernelName.begin(), KernelName.end()); + auto *KernelNameGV = getOrCreateGlobalString("__msan_kernel", KernelName, + kSpirOffloadConstantAS); - uintptr_t SanitizerFlags = 0; - SanitizerFlags |= ClSpirOffloadLocals ? SanitizedKernelFlags::CHECK_LOCALS - : SanitizedKernelFlags::NO_CHECK; - SanitizerFlags |= ClSpirOffloadPrivates - ? SanitizedKernelFlags::CHECK_PRIVATES - : SanitizedKernelFlags::NO_CHECK; - SanitizerFlags |= TrackOrigins != 0 - ? SanitizedKernelFlags::MSAN_TRACK_ORIGINS - : SanitizedKernelFlags::NO_CHECK; + uintptr_t SanitizerFlags = 0; + SanitizerFlags |= ClSpirOffloadLocals ? SanitizedKernelFlags::CHECK_LOCALS + : SanitizedKernelFlags::NO_CHECK; + SanitizerFlags |= ClSpirOffloadPrivates + ? SanitizedKernelFlags::CHECK_PRIVATES + : SanitizedKernelFlags::NO_CHECK; + SanitizerFlags |= TrackOrigins != 0 + ? SanitizedKernelFlags::MSAN_TRACK_ORIGINS + : SanitizedKernelFlags::NO_CHECK; - SpirKernelsMetadata.emplace_back(ConstantStruct::get( - StructTy, ConstantExpr::getPointerCast(KernelNameGV, IntptrTy), - ConstantInt::get(IntptrTy, KernelName.size()), - ConstantInt::get(IntptrTy, SanitizerFlags))); - } + SpirKernelsMetadata.emplace_back(ConstantStruct::get( + StructTy, ConstantExpr::getPointerCast(KernelNameGV, IntptrTy), + ConstantInt::get(IntptrTy, KernelName.size()), + ConstantInt::get(IntptrTy, SanitizerFlags))); + } // Create global variable to record spirv kernels' information ArrayType *ArrayTy = ArrayType::get(StructTy, SpirKernelsMetadata.size()); @@ -1361,6 +1366,9 @@ PreservedAnalyses MemorySanitizerPass::run(Module &M, MemorySanitizerOnSpirv MsanSpirv(M); Modified |= MsanSpirv.instrumentModule(Options.TrackOrigins); + // FIXME: W/A skip instrumentation if this module has ESIMD + if (MsanSpirv.hasESIMD()) + return PreservedAnalyses::none(); auto &FAM = AM.getResult(M).getManager(); for (Function &F : M) { diff --git a/llvm/lib/Transforms/Instrumentation/SPIRVSanitizerCommonUtils.cpp b/llvm/lib/Transforms/Instrumentation/SPIRVSanitizerCommonUtils.cpp index b11526ec9f159..4b8fb80ba8f69 100644 --- a/llvm/lib/Transforms/Instrumentation/SPIRVSanitizerCommonUtils.cpp +++ b/llvm/lib/Transforms/Instrumentation/SPIRVSanitizerCommonUtils.cpp @@ -81,4 +81,11 @@ computeKernelMetadataUniqueId(StringRef Prefix, return UniqueId; } +bool hasESIMDKernel(Module &M) { + for (auto &F : M) + if (F.hasMetadata("sycl_explicit_simd")) + return true; + return false; +} + } // namespace llvm diff --git a/llvm/lib/Transforms/Instrumentation/ThreadSanitizer.cpp b/llvm/lib/Transforms/Instrumentation/ThreadSanitizer.cpp index f04cf538caa17..145f94fad0adc 100644 --- a/llvm/lib/Transforms/Instrumentation/ThreadSanitizer.cpp +++ b/llvm/lib/Transforms/Instrumentation/ThreadSanitizer.cpp @@ -117,6 +117,7 @@ struct ThreadSanitizerOnSpirv { ThreadSanitizerOnSpirv(Module &M) : M(M), C(M.getContext()), DL(M.getDataLayout()) { IntptrTy = DL.getIntPtrType(C); + HasESIMD = hasESIMDKernel(M); } void initialize(); @@ -134,6 +135,8 @@ struct ThreadSanitizerOnSpirv { bool isUnsupportedSPIRAccess(Value *Addr, Instruction *Inst); + bool hasESIMD() { return HasESIMD; } + private: void instrumentGlobalVariables(); @@ -154,6 +157,7 @@ struct ThreadSanitizerOnSpirv { Module &M; LLVMContext &C; const DataLayout &DL; + bool HasESIMD; Type *IntptrTy; StringMap GlobalStringMap; @@ -695,20 +699,21 @@ void ThreadSanitizerOnSpirv::instrumentKernelsMetadata() { // uptr unmangled_kernel_name_size StructType *StructTy = StructType::get(IntptrTy, IntptrTy); - for (Function &F : M) { - if (F.getCallingConv() != CallingConv::SPIR_KERNEL) - continue; + if (!HasESIMD) + for (Function &F : M) { + if (F.getCallingConv() != CallingConv::SPIR_KERNEL) + continue; - if (isSupportedSPIRKernel(F)) { - auto KernelName = F.getName(); - KernelNamesBytes.append(KernelName.begin(), KernelName.end()); - auto *KernelNameGV = GetOrCreateGlobalString("__tsan_kernel", KernelName, - kSpirOffloadConstantAS); - SpirKernelsMetadata.emplace_back(ConstantStruct::get( - StructTy, ConstantExpr::getPointerCast(KernelNameGV, IntptrTy), - ConstantInt::get(IntptrTy, KernelName.size()))); + if (isSupportedSPIRKernel(F)) { + auto KernelName = F.getName(); + KernelNamesBytes.append(KernelName.begin(), KernelName.end()); + auto *KernelNameGV = GetOrCreateGlobalString( + "__tsan_kernel", KernelName, kSpirOffloadConstantAS); + SpirKernelsMetadata.emplace_back(ConstantStruct::get( + StructTy, ConstantExpr::getPointerCast(KernelNameGV, IntptrTy), + ConstantInt::get(IntptrTy, KernelName.size()))); + } } - } // Create global variable to record spirv kernels' information ArrayType *ArrayTy = ArrayType::get(StructTy, SpirKernelsMetadata.size()); @@ -1076,6 +1081,10 @@ bool ThreadSanitizer::sanitizeFunction(Function &F, bool SanitizeFunction = F.hasFnAttribute(Attribute::SanitizeThread); const DataLayout &DL = F.getDataLayout(); + // FIXME: W/A skip instrumentation if this module has ESIMD + if (Spirv && Spirv->hasESIMD()) + return false; + // Traverse all instructions, collect loads/stores/returns, check for calls. for (auto &BB : F) { for (auto &Inst : BB) { diff --git a/llvm/test/Instrumentation/MemorySanitizer/SPIRV/sycl_esimd.ll b/llvm/test/Instrumentation/MemorySanitizer/SPIRV/sycl_esimd.ll new file mode 100644 index 0000000000000..cc49a62b56384 --- /dev/null +++ b/llvm/test/Instrumentation/MemorySanitizer/SPIRV/sycl_esimd.ll @@ -0,0 +1,34 @@ +; RUN: opt < %s -passes=msan -msan-instrumentation-with-call-threshold=0 -msan-eager-checks=1 -msan-spir-privates=0 -S | FileCheck %s + +target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-n8:16:32:64-G1" +target triple = "spir64-unknown-unknown" + +;CHECK: @__MsanKernelMetadata +;CHECK-SAME: [0 x { i64, i64, i64 }] + +define spir_kernel void @test(ptr addrspace(1) noundef align 4 %_arg_array) sanitize_memory { +; CHECK-LABEL: define spir_kernel void @test +entry: + %0 = load i32, ptr addrspace(1) %_arg_array, align 4 + %call = call spir_func i32 @foo(i32 %0) + store i32 %call, ptr addrspace(1) %_arg_array, align 4 +; CHECK-NOT: call void @__msan_maybe_warning + ret void +} + +define spir_kernel void @test_esimd(ptr addrspace(1) noundef align 4 %_arg_array) sanitize_memory !sycl_explicit_simd !0 { +; CHECK-LABEL: define spir_kernel void @test_esimd +entry: + %0 = load i32, ptr addrspace(1) %_arg_array, align 4 + %call = call spir_func i32 @foo(i32 %0) + store i32 %call, ptr addrspace(1) %_arg_array, align 4 +; CHECK-NOT: call void @__msan_maybe_warning + ret void +} + +define spir_func i32 @foo(i32 %data) sanitize_memory { +entry: + ret i32 %data +} + +!0 = !{} diff --git a/llvm/test/Instrumentation/ThreadSanitizer/SPIRV/sycl_esimd.ll b/llvm/test/Instrumentation/ThreadSanitizer/SPIRV/sycl_esimd.ll new file mode 100644 index 0000000000000..c73b07fe097d2 --- /dev/null +++ b/llvm/test/Instrumentation/ThreadSanitizer/SPIRV/sycl_esimd.ll @@ -0,0 +1,32 @@ +; RUN: opt < %s -passes='function(tsan),module(tsan-module)' -tsan-instrument-func-entry-exit=0 -tsan-instrument-memintrinsics=0 -S | FileCheck %s +target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-n8:16:32:64-G1" +target triple = "spir64-unknown-unknown" + +; CHECK: @__TsanKernelMetadata +; CHECK-SAME: [0 x { i64, i64 }] + +; Function Attrs: sanitize_thread +define spir_kernel void @test(ptr addrspace(4) %a) #0 { +; CHECK-LABEL: void @test +entry: + %tmp1 = load i8, ptr addrspace(4) %a, align 1 + %inc = add i8 %tmp1, 1 + ; CHECK-NOT: call void @__tsan_write + store i8 %inc, ptr addrspace(4) %a, align 1 + ret void +} + +; Function Attrs: sanitize_thread +define spir_kernel void @test_esimd(ptr addrspace(4) %a) #0 !sycl_explicit_simd !0 { +; CHECK-LABEL: void @test_esimd +entry: + %tmp1 = load i16, ptr addrspace(4) %a, align 2 + %inc = add i16 %tmp1, 1 + ; CHECK-NOT: call void @__tsan_write + store i16 %inc, ptr addrspace(4) %a, align 2 + ret void +} + +attributes #0 = { sanitize_thread } + +!0 = !{}