diff --git a/source/loader/layers/sanitizer/asan/asan_ddi.cpp b/source/loader/layers/sanitizer/asan/asan_ddi.cpp index 774ce3a61d..741b4d421c 100644 --- a/source/loader/layers/sanitizer/asan/asan_ddi.cpp +++ b/source/loader/layers/sanitizer/asan/asan_ddi.cpp @@ -55,6 +55,9 @@ ur_result_t setupContext(ur_context_handle_t Context, uint32_t numDevices, bool isInstrumentedKernel(ur_kernel_handle_t hKernel) { auto hProgram = GetProgram(hKernel); auto PI = getAsanInterceptor()->getProgramInfo(hProgram); + if (PI == nullptr) { + return false; + } return PI->isKernelInstrumented(hKernel); } @@ -290,8 +293,9 @@ __urdlllocal ur_result_t UR_APICALL urProgramRetain( UR_CALL(pfnRetain(hProgram)); auto ProgramInfo = getAsanInterceptor()->getProgramInfo(hProgram); - UR_ASSERT(ProgramInfo != nullptr, UR_RESULT_ERROR_INVALID_VALUE); - ProgramInfo->RefCount++; + if (ProgramInfo != nullptr) { + ProgramInfo->RefCount++; + } return UR_RESULT_SUCCESS; } @@ -364,6 +368,7 @@ __urdlllocal ur_result_t UR_APICALL urProgramLink( UR_CALL(pfnProgramLink(hContext, count, phPrograms, pOptions, phProgram)); + UR_CALL(getAsanInterceptor()->insertProgram(*phProgram)); UR_CALL(getAsanInterceptor()->registerProgram(*phProgram)); return UR_RESULT_SUCCESS; @@ -395,6 +400,7 @@ ur_result_t UR_APICALL urProgramLinkExp( UR_CALL(pfnProgramLinkExp(hContext, numDevices, phDevices, count, phPrograms, pOptions, phProgram)); + UR_CALL(getAsanInterceptor()->insertProgram(*phProgram)); UR_CALL(getAsanInterceptor()->registerProgram(*phProgram)); return UR_RESULT_SUCCESS; @@ -417,8 +423,7 @@ ur_result_t UR_APICALL urProgramRelease( UR_CALL(pfnProgramRelease(hProgram)); auto ProgramInfo = getAsanInterceptor()->getProgramInfo(hProgram); - UR_ASSERT(ProgramInfo != nullptr, UR_RESULT_ERROR_INVALID_VALUE); - if (--ProgramInfo->RefCount == 0) { + if (ProgramInfo != nullptr && --ProgramInfo->RefCount == 0) { UR_CALL(getAsanInterceptor()->unregisterProgram(hProgram)); UR_CALL(getAsanInterceptor()->eraseProgram(hProgram)); } diff --git a/source/loader/layers/sanitizer/asan/asan_interceptor.cpp b/source/loader/layers/sanitizer/asan/asan_interceptor.cpp index 271d846990..a818cf3642 100644 --- a/source/loader/layers/sanitizer/asan/asan_interceptor.cpp +++ b/source/loader/layers/sanitizer/asan/asan_interceptor.cpp @@ -221,25 +221,28 @@ ur_result_t AsanInterceptor::releaseMemory(ur_context_handle_t Context, if (ReleaseList.size()) { std::scoped_lock Guard(m_AllocationMapMutex); for (auto &It : ReleaseList) { + auto ToFreeAllocInfo = It->second; getContext()->logger.info("Quarantine Free: {}", - (void *)It->second->AllocBegin); + (void *)ToFreeAllocInfo->AllocBegin); - ContextInfo->Stats.UpdateUSMRealFreed(AllocInfo->AllocSize, - AllocInfo->getRedzoneSize()); + ContextInfo->Stats.UpdateUSMRealFreed( + ToFreeAllocInfo->AllocSize, ToFreeAllocInfo->getRedzoneSize()); - m_AllocationMap.erase(It); - if (AllocInfo->Type == AllocType::HOST_USM) { + if (ToFreeAllocInfo->Type == AllocType::HOST_USM) { for (auto &Device : ContextInfo->DeviceList) { UR_CALL(getDeviceInfo(Device)->Shadow->ReleaseShadow( - AllocInfo)); + ToFreeAllocInfo)); } } else { - UR_CALL(getDeviceInfo(AllocInfo->Device) - ->Shadow->ReleaseShadow(AllocInfo)); + UR_CALL(getDeviceInfo(ToFreeAllocInfo->Device) + ->Shadow->ReleaseShadow(ToFreeAllocInfo)); } UR_CALL(getContext()->urDdiTable.USM.pfnFree( - Context, (void *)(It->second->AllocBegin))); + Context, (void *)(ToFreeAllocInfo->AllocBegin))); + + // Erase it at last to avoid use-after-free. + m_AllocationMap.erase(It); } } ContextInfo->Stats.UpdateUSMFreed(AllocInfo->AllocSize); @@ -433,6 +436,7 @@ ur_result_t AsanInterceptor::registerProgram(ur_program_handle_t Program) { ur_result_t AsanInterceptor::unregisterProgram(ur_program_handle_t Program) { auto ProgramInfo = getProgramInfo(Program); + assert(ProgramInfo != nullptr && "unregistered program!"); for (auto AI : ProgramInfo->AllocInfoForGlobals) { UR_CALL(getDeviceInfo(AI->Device)->Shadow->ReleaseShadow(AI)); @@ -480,6 +484,7 @@ ur_result_t AsanInterceptor::registerSpirKernels(ur_program_handle_t Program) { } auto PI = getProgramInfo(Program); + assert(PI != nullptr && "unregistered program!"); for (const auto &SKI : SKInfo) { if (SKI.Size == 0) { continue; @@ -516,6 +521,7 @@ AsanInterceptor::registerDeviceGlobals(ur_program_handle_t Program) { auto Context = GetContext(Program); auto ContextInfo = getContextInfo(Context); auto ProgramInfo = getProgramInfo(Program); + assert(ProgramInfo != nullptr && "unregistered program!"); for (auto Device : Devices) { ManagedQueue Queue(Context, Device); diff --git a/source/loader/layers/sanitizer/asan/asan_interceptor.hpp b/source/loader/layers/sanitizer/asan/asan_interceptor.hpp index 926be1388e..09075e1aaa 100644 --- a/source/loader/layers/sanitizer/asan/asan_interceptor.hpp +++ b/source/loader/layers/sanitizer/asan/asan_interceptor.hpp @@ -266,8 +266,10 @@ class AsanInterceptor { std::shared_ptr getProgramInfo(ur_program_handle_t Program) { std::shared_lock Guard(m_ProgramMapMutex); - assert(m_ProgramMap.find(Program) != m_ProgramMap.end()); - return m_ProgramMap[Program]; + if (m_ProgramMap.find(Program) != m_ProgramMap.end()) { + return m_ProgramMap[Program]; + } + return nullptr; } std::shared_ptr getKernelInfo(ur_kernel_handle_t Kernel) { diff --git a/source/loader/layers/sanitizer/asan/asan_shadow.cpp b/source/loader/layers/sanitizer/asan/asan_shadow.cpp index 59897a426f..d3f3761dc7 100644 --- a/source/loader/layers/sanitizer/asan/asan_shadow.cpp +++ b/source/loader/layers/sanitizer/asan/asan_shadow.cpp @@ -250,6 +250,7 @@ ur_result_t ShadowMemoryGPU::ReleaseShadow(std::shared_ptr AI) { getContext()->logger.debug("urVirtualMemUnmap: {} ~ {}", (void *)MappedPtr, (void *)(MappedPtr + PageSize - 1)); + VirtualMemMaps.erase(MappedPtr); } }