Skip to content

Commit f2c4a82

Browse files
committed
Add sanddi interception
1 parent 6addfea commit f2c4a82

File tree

2 files changed

+157
-0
lines changed

2 files changed

+157
-0
lines changed

source/loader/layers/sanitizer/common.hpp

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,41 @@ inline constexpr uptr ComputeRZLog(uptr user_requested_size) {
6565
return rz_log;
6666
}
6767

68+
/// Returns the next integer (mod 2**64) that is greater than or equal to
69+
/// \p Value and is a multiple of \p Align. \p Align must be non-zero.
70+
///
71+
/// Examples:
72+
/// \code
73+
/// alignTo(5, 8) = 8
74+
/// alignTo(17, 8) = 24
75+
/// alignTo(~0LL, 8) = 0
76+
/// alignTo(321, 255) = 510
77+
/// \endcode
78+
inline uint64_t AlignTo(uint64_t Value, uint64_t Align) {
79+
assert(Align != 0u && "Align can't be 0.");
80+
return (Value + Align - 1) / Align * Align;
81+
}
82+
83+
inline uint64_t GetSizeAndRedzoneSizeForLocal(uint64_t Size,
84+
uint64_t Granularity,
85+
uint64_t Alignment) {
86+
uint64_t Res = 0;
87+
if (Size <= 4) {
88+
Res = 16;
89+
} else if (Size <= 16) {
90+
Res = 32;
91+
} else if (Size <= 128) {
92+
Res = Size + 32;
93+
} else if (Size <= 512) {
94+
Res = Size + 64;
95+
} else if (Size <= 4096) {
96+
Res = Size + 128;
97+
} else {
98+
Res = Size + 256;
99+
}
100+
return AlignTo(std::max(Res, 2 * Granularity), Alignment);
101+
}
102+
68103
// ================================================================
69104

70105
// Trace an internal UR call; returns in case of an error.

source/loader/layers/sanitizer/ur_sanddi.cpp

Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -217,6 +217,7 @@ __urdlllocal ur_result_t UR_APICALL urEnqueueKernelLaunch(
217217
USMLaunchInfo LaunchInfo(GetContext(hQueue), GetDevice(hQueue),
218218
pGlobalWorkSize, pLocalWorkSize, pGlobalWorkOffset,
219219
workDim);
220+
UR_CALL(LaunchInfo.initialize());
220221

221222
UR_CALL(context.interceptor->preLaunchKernel(hKernel, hQueue, LaunchInfo));
222223

@@ -317,6 +318,90 @@ __urdlllocal ur_result_t UR_APICALL urContextRelease(
317318
return result;
318319
}
319320

321+
///////////////////////////////////////////////////////////////////////////////
322+
/// @brief Intercept function for urKernelCreate
323+
__urdlllocal ur_result_t UR_APICALL urKernelCreate(
324+
ur_program_handle_t hProgram, ///< [in] handle of the program instance
325+
const char *pKernelName, ///< [in] pointer to null-terminated string.
326+
ur_kernel_handle_t
327+
*phKernel ///< [out] pointer to handle of kernel object created.
328+
) {
329+
auto pfnCreate = context.urDdiTable.Kernel.pfnCreate;
330+
331+
if (nullptr == pfnCreate) {
332+
return UR_RESULT_ERROR_UNSUPPORTED_FEATURE;
333+
}
334+
335+
context.logger.debug("==== urKernelCreate");
336+
337+
UR_CALL(pfnCreate(hProgram, pKernelName, phKernel));
338+
UR_CALL(context.interceptor->insertKernel(*phKernel));
339+
340+
return UR_RESULT_SUCCESS;
341+
}
342+
343+
///////////////////////////////////////////////////////////////////////////////
344+
/// @brief Intercept function for urKernelRelease
345+
__urdlllocal ur_result_t urKernelRelease(
346+
ur_kernel_handle_t hKernel ///< [in] handle for the Kernel to release
347+
) {
348+
auto pfnRelease = context.urDdiTable.Kernel.pfnRelease;
349+
350+
if (nullptr == pfnRelease) {
351+
return UR_RESULT_ERROR_UNSUPPORTED_FEATURE;
352+
}
353+
354+
context.logger.debug("==== urKernelRelease");
355+
UR_CALL(pfnRelease(hKernel));
356+
357+
if (auto KernelInfo = context.interceptor->getKernelInfo(hKernel)) {
358+
uint32_t RefCount;
359+
UR_CALL(context.urDdiTable.Kernel.pfnGetInfo(
360+
hKernel, UR_KERNEL_INFO_REFERENCE_COUNT, sizeof(RefCount),
361+
&RefCount, nullptr));
362+
if (RefCount == 1) {
363+
UR_CALL(context.interceptor->eraseKernel(hKernel));
364+
}
365+
}
366+
367+
return UR_RESULT_SUCCESS;
368+
}
369+
370+
///////////////////////////////////////////////////////////////////////////////
371+
/// @brief Intercept function for urKernelSetArgLocal
372+
__urdlllocal ur_result_t UR_APICALL urKernelSetArgLocal(
373+
ur_kernel_handle_t hKernel, ///< [in] handle of the kernel object
374+
uint32_t argIndex, ///< [in] argument index in range [0, num args - 1]
375+
size_t
376+
argSize, ///< [in] size of the local buffer to be allocated by the runtime
377+
const ur_kernel_arg_local_properties_t
378+
*pProperties ///< [in][optional] pointer to local buffer properties.
379+
) {
380+
auto pfnSetArgLocal = context.urDdiTable.Kernel.pfnSetArgLocal;
381+
382+
if (nullptr == pfnSetArgLocal) {
383+
return UR_RESULT_ERROR_UNSUPPORTED_FEATURE;
384+
}
385+
386+
context.logger.debug("==== urKernelSetArgLocal (argIndex={}, argSize={})",
387+
argIndex, argSize);
388+
389+
{
390+
auto KI = context.interceptor->getKernelInfo(hKernel);
391+
std::scoped_lock<ur_shared_mutex> Guard(KI->Mutex);
392+
// TODO: get local variable alignment
393+
auto argSizeWithRZ = GetSizeAndRedzoneSizeForLocal(
394+
argSize, ASAN_SHADOW_GRANULARITY, ASAN_SHADOW_GRANULARITY);
395+
KI->LocalArgs[argIndex] = LocalArgsInfo{argSize, argSizeWithRZ};
396+
argSize = argSizeWithRZ;
397+
}
398+
399+
ur_result_t result =
400+
pfnSetArgLocal(hKernel, argIndex, argSize, pProperties);
401+
402+
return result;
403+
}
404+
320405
///////////////////////////////////////////////////////////////////////////////
321406
/// @brief Exported function for filling application's Context table
322407
/// with current process' addresses
@@ -410,6 +495,38 @@ __urdlllocal ur_result_t UR_APICALL urGetProgramExpProcAddrTable(
410495
return result;
411496
}
412497
///////////////////////////////////////////////////////////////////////////////
498+
/// @brief Exported function for filling application's Kernel table
499+
/// with current process' addresses
500+
///
501+
/// @returns
502+
/// - ::UR_RESULT_SUCCESS
503+
/// - ::UR_RESULT_ERROR_INVALID_NULL_POINTER
504+
/// - ::UR_RESULT_ERROR_UNSUPPORTED_VERSION
505+
__urdlllocal ur_result_t UR_APICALL urGetKernelProcAddrTable(
506+
ur_api_version_t version, ///< [in] API version requested
507+
ur_kernel_dditable_t
508+
*pDdiTable ///< [in,out] pointer to table of DDI function pointers
509+
) {
510+
if (nullptr == pDdiTable) {
511+
return UR_RESULT_ERROR_INVALID_NULL_POINTER;
512+
}
513+
514+
if (UR_MAJOR_VERSION(ur_sanitizer_layer::context.version) !=
515+
UR_MAJOR_VERSION(version) ||
516+
UR_MINOR_VERSION(ur_sanitizer_layer::context.version) >
517+
UR_MINOR_VERSION(version)) {
518+
return UR_RESULT_ERROR_UNSUPPORTED_VERSION;
519+
}
520+
521+
ur_result_t result = UR_RESULT_SUCCESS;
522+
523+
pDdiTable->pfnCreate = ur_sanitizer_layer::urKernelCreate;
524+
pDdiTable->pfnRelease = ur_sanitizer_layer::urKernelRelease;
525+
pDdiTable->pfnSetArgLocal = ur_sanitizer_layer::urKernelSetArgLocal;
526+
527+
return result;
528+
}
529+
///////////////////////////////////////////////////////////////////////////////
413530
/// @brief Exported function for filling application's Enqueue table
414531
/// with current process' addresses
415532
///
@@ -509,6 +626,11 @@ ur_result_t context_t::init(ur_dditable_t *dditable,
509626
UR_API_VERSION_CURRENT, &dditable->Context);
510627
}
511628

629+
if (UR_RESULT_SUCCESS == result) {
630+
result = ur_sanitizer_layer::urGetKernelProcAddrTable(
631+
UR_API_VERSION_CURRENT, &dditable->Kernel);
632+
}
633+
512634
if (UR_RESULT_SUCCESS == result) {
513635
result = ur_sanitizer_layer::urGetProgramProcAddrTable(
514636
UR_API_VERSION_CURRENT, &dditable->Program);

0 commit comments

Comments
 (0)