diff --git a/.github/pins/pti.txt b/.github/pins/pti.txt new file mode 100644 index 0000000000..ca98925e9d --- /dev/null +++ b/.github/pins/pti.txt @@ -0,0 +1 @@ +15a201d25e5659692613b98ee33513263b689101 diff --git a/.github/workflows/build-test-reusable.yml b/.github/workflows/build-test-reusable.yml index 722963a9a1..7fe0ddaa6e 100644 --- a/.github/workflows/build-test-reusable.yml +++ b/.github/workflows/build-test-reusable.yml @@ -285,9 +285,22 @@ jobs: run: | echo "TRITON_TEST_CMD=${{ needs.build.outputs.test-triton-command }}" | tee -a $GITHUB_ENV - - name: Run Proton tests + - name: Build PTI && Run Proton tests if: matrix.suite == 'rest' && inputs.driver_version == 'rolling' && inputs.device == 'max1100' run: | + PTI_COMMIT_ID="$(<.github/pins/pti.txt)" + git clone https://github.com/intel/pti-gpu.git + cd pti-gpu + git checkout $PTI_COMMIT_ID + cd sdk + cmake --preset linux-icpx-release + BUILD_TESTING=1 PTI_BUILD_SAMPLES=1 cmake --build --preset linux-icpx-release + + PTI_LIBS_DIR="$(pwd)/build-linux-icpx-release/lib/" + cd ../.. + + export LD_LIBRARY_PATH=$PTI_LIBS_DIR:$LD_LIBRARY_PATH + export TRITON_XPUPTI_LIB_PATH=$PTI_LIBS_DIR cd third_party/proton/test # FIXME: enable 'test_record.py' back pytest test_api.py test_lib.py test_profile.py test_viewer.py -s -v diff --git a/.github/workflows/triton-benchmarks.yml b/.github/workflows/triton-benchmarks.yml index 449a13bd46..f29d6ad4e9 100644 --- a/.github/workflows/triton-benchmarks.yml +++ b/.github/workflows/triton-benchmarks.yml @@ -116,9 +116,24 @@ jobs: cd benchmarks pip install . + - name: Build PTI from source + run: | + PTI_COMMIT_ID="$(<.github/pins/pti.txt)" + git clone https://github.com/intel/pti-gpu.git + cd pti-gpu + git checkout $PTI_COMMIT_ID + cd sdk + cmake --preset linux-icpx-release + BUILD_TESTING=1 PTI_BUILD_SAMPLES=1 cmake --build --preset linux-icpx-release + + PTI_LIBS_DIR="$(pwd)/build-linux-icpx-release/lib/" + ls $PTI_LIBS_DIR + echo "PTI_LIBS_DIR=$PTI_LIBS_DIR" >> $GITHUB_ENV + - name: Run Triton Softmax kernel benchmark if: ${{ steps.install.outcome == 'success' && !cancelled() && (inputs.benchmarks == '' || contains(fromJson(inputs.benchmarks || '[]'), 'fused_softmax.py')) && !contains(fromJson(inputs.skip_benchmarks || '[]'), 'fused_softmax.py') }} run: | + export LD_LIBRARY_PATH=$PTI_LIBS_DIR:$LD_LIBRARY_PATH cd benchmarks/triton_kernels_benchmark python fused_softmax.py --reports $REPORTS --n_runs $N_RUNS source ../../scripts/capture-hw-details.sh @@ -129,6 +144,7 @@ jobs: - name: Run Triton GEMM kernel benchmark if: ${{ steps.install.outcome == 'success' && !cancelled() && (inputs.benchmarks == '' || contains(fromJson(inputs.benchmarks || '[]'), 'gemm_benchmark.py')) && !contains(fromJson(inputs.skip_benchmarks || '[]'), 'gemm_benchmark.py') }} run: | + export LD_LIBRARY_PATH=$PTI_LIBS_DIR:$LD_LIBRARY_PATH cd benchmarks/triton_kernels_benchmark python gemm_benchmark.py --reports $REPORTS --n_runs $N_RUNS mv $REPORTS/matmul-performance.csv $REPORTS/matmul-performance-base.csv @@ -142,6 +158,7 @@ jobs: - name: Run Triton GEMM kernel benchmark - with tensor of pointer if: ${{ steps.install.outcome == 'success' && !cancelled() && (inputs.benchmarks == '' || contains(fromJson(inputs.benchmarks || '[]'), 'gemm_tensor_of_ptr_benchmark.py')) && !contains(fromJson(inputs.skip_benchmarks || '[]'), 'gemm_tensor_of_ptr_benchmark.py') }} run: | + export LD_LIBRARY_PATH=$PTI_LIBS_DIR:$LD_LIBRARY_PATH cd benchmarks/triton_kernels_benchmark python gemm_tensor_of_ptr_benchmark.py --reports $REPORTS --n_runs $N_RUNS source ../../scripts/capture-hw-details.sh @@ -154,6 +171,7 @@ jobs: - name: Run Triton GEMM kernel benchmark - with tensor descriptor if: ${{ steps.install.outcome == 'success' && !cancelled() && (inputs.benchmarks == '' || contains(fromJson(inputs.benchmarks || '[]'), 'gemm_tensor_desc_benchmark.py')) && !contains(fromJson(inputs.skip_benchmarks || '[]'), 'gemm_tensor_desc_benchmark.py') }} run: | + export LD_LIBRARY_PATH=$PTI_LIBS_DIR:$LD_LIBRARY_PATH cd benchmarks/triton_kernels_benchmark python gemm_tensor_desc_benchmark.py --reports $REPORTS --n_runs $N_RUNS source ../../scripts/capture-hw-details.sh @@ -166,6 +184,7 @@ jobs: - name: Run Triton GEMM (A@B^t) kernel benchmark if: ${{ steps.install.outcome == 'success' && !cancelled() && (inputs.benchmarks == '' || contains(fromJson(inputs.benchmarks || '[]'), 'gemm_benchmark.py_abt')) && !contains(fromJson(inputs.skip_benchmarks || '[]'), 'gemm_benchmark.py_abt') }} run: | + export LD_LIBRARY_PATH=$PTI_LIBS_DIR:$LD_LIBRARY_PATH cd benchmarks/triton_kernels_benchmark TRANSPOSE_B=1 python gemm_benchmark.py --reports $REPORTS --n_runs $N_RUNS mv $REPORTS/matmul-performance.csv $REPORTS/matmul-performance-bt.csv @@ -177,6 +196,7 @@ jobs: - name: Run Triton GEMM (A^t@B) kernel benchmark if: ${{ steps.install.outcome == 'success' && !cancelled() && (inputs.benchmarks == '' || contains(fromJson(inputs.benchmarks || '[]'), 'gemm_benchmark.py_atb')) && !contains(fromJson(inputs.skip_benchmarks || '[]'), 'gemm_benchmark.py_atb') }} run: | + export LD_LIBRARY_PATH=$PTI_LIBS_DIR:$LD_LIBRARY_PATH cd benchmarks/triton_kernels_benchmark TRANSPOSE_A=1 python gemm_benchmark.py --reports $REPORTS --n_runs $N_RUNS mv $REPORTS/matmul-performance.csv $REPORTS/matmul-performance-at.csv @@ -188,6 +208,7 @@ jobs: - name: Run Triton GEMM (stream-k) kernel benchmark if: ${{ steps.install.outcome == 'success' && !cancelled() && (inputs.benchmarks == '' || contains(fromJson(inputs.benchmarks || '[]'), 'gemm_streamk_benchmark.py')) && !contains(fromJson(inputs.skip_benchmarks || '[]'), 'gemm_streamk_benchmark.py') }} run: | + export LD_LIBRARY_PATH=$PTI_LIBS_DIR:$LD_LIBRARY_PATH cd benchmarks/triton_kernels_benchmark python gemm_streamk_benchmark.py --reports $REPORTS --n_runs $N_RUNS source ../../scripts/capture-hw-details.sh @@ -197,6 +218,7 @@ jobs: - name: Run Triton GEMM (split-k) kernel benchmark if: ${{ steps.install.outcome == 'success' && !cancelled() && (inputs.benchmarks == '' || contains(fromJson(inputs.benchmarks || '[]'), 'gemm_splitk_benchmark.py')) && !contains(fromJson(inputs.skip_benchmarks || '[]'), 'gemm_splitk_benchmark.py') }} run: | + export LD_LIBRARY_PATH=$PTI_LIBS_DIR:$LD_LIBRARY_PATH cd benchmarks/triton_kernels_benchmark python gemm_splitk_benchmark.py --reports $REPORTS --n_runs $N_RUNS source ../../scripts/capture-hw-details.sh @@ -206,6 +228,7 @@ jobs: - name: Run Triton GEMM + PreOp (exp) kernel benchmark if: ${{ steps.install.outcome == 'success' && !cancelled() && (inputs.benchmarks == '' || contains(fromJson(inputs.benchmarks || '[]'), 'gemm_preop_exp_benchmark.py')) && !contains(fromJson(inputs.skip_benchmarks || '[]'), 'gemm_preop_exp_benchmark.py') }} run: | + export LD_LIBRARY_PATH=$PTI_LIBS_DIR:$LD_LIBRARY_PATH cd benchmarks/triton_kernels_benchmark python gemm_preop_exp_benchmark.py --reports $REPORTS --n_runs $N_RUNS source ../../scripts/capture-hw-details.sh @@ -214,6 +237,7 @@ jobs: - name: Run Triton GEMM + PostOp (Gelu) kernel benchmark if: ${{ steps.install.outcome == 'success' && !cancelled() && (inputs.benchmarks == '' || contains(fromJson(inputs.benchmarks || '[]'), 'gemm_postop_gelu_benchmark.py')) && !contains(fromJson(inputs.skip_benchmarks || '[]'), 'gemm_postop_gelu_benchmark.py') }} run: | + export LD_LIBRARY_PATH=$PTI_LIBS_DIR:$LD_LIBRARY_PATH cd benchmarks/triton_kernels_benchmark python gemm_postop_gelu_benchmark.py --reports $REPORTS --n_runs $N_RUNS source ../../scripts/capture-hw-details.sh @@ -222,6 +246,7 @@ jobs: - name: Run Triton GEMM + PostOp (add matrix) kernel benchmark bfloat16 if: ${{ steps.install.outcome == 'success' && !cancelled() && (inputs.benchmarks == '' || contains(fromJson(inputs.benchmarks || '[]'), 'gemm_postop_addmatrix_benchmark_bfloat16.py')) && !contains(fromJson(inputs.skip_benchmarks || '[]'), 'gemm_postop_addmatrix_benchmark_bfloat16.py') }} run: | + export LD_LIBRARY_PATH=$PTI_LIBS_DIR:$LD_LIBRARY_PATH cd benchmarks/triton_kernels_benchmark python gemm_postop_addmatrix_benchmark.py --reports $REPORTS --n_runs $N_RUNS source ../../scripts/capture-hw-details.sh @@ -231,6 +256,7 @@ jobs: - name: Run Triton GEMM + PostOp (add matrix) kernel benchmark int8 if: ${{ steps.install.outcome == 'success' && !cancelled() && (inputs.benchmarks == '' || contains(fromJson(inputs.benchmarks || '[]'), 'gemm_postop_addmatrix_benchmark_int8.py')) && !contains(fromJson(inputs.skip_benchmarks || '[]'), 'gemm_postop_addmatrix_benchmark_int8.py') }} run: | + export LD_LIBRARY_PATH=$PTI_LIBS_DIR:$LD_LIBRARY_PATH cd benchmarks/triton_kernels_benchmark INT8_ONLY=1 python gemm_postop_addmatrix_benchmark.py --reports $REPORTS --n_runs $N_RUNS source ../../scripts/capture-hw-details.sh @@ -240,6 +266,7 @@ jobs: - name: Run Triton FA fwd kernel benchmark if: ${{ steps.install.outcome == 'success' && !cancelled() && (inputs.benchmarks == '' || contains(fromJson(inputs.benchmarks || '[]'), 'flash_attention_benchmark.py')) && !contains(fromJson(inputs.skip_benchmarks || '[]'), 'flash_attention_benchmark.py') }} run: | + export LD_LIBRARY_PATH=$PTI_LIBS_DIR:$LD_LIBRARY_PATH cd benchmarks/triton_kernels_benchmark python flash_attention_benchmark.py --reports $REPORTS --n_runs $N_RUNS @@ -250,6 +277,7 @@ jobs: - name: Run Triton FA bwd kernel benchmark if: ${{ steps.install.outcome == 'success' && !cancelled() && (inputs.benchmarks == '' || contains(fromJson(inputs.benchmarks || '[]'), 'flash_attention_bwd_benchmark.py')) && !contains(fromJson(inputs.skip_benchmarks || '[]'), 'flash_attention_bwd_benchmark.py') }} run: | + export LD_LIBRARY_PATH=$PTI_LIBS_DIR:$LD_LIBRARY_PATH cd benchmarks/triton_kernels_benchmark FA_KERNEL_MODE="bwd" \ python flash_attention_benchmark.py --reports $REPORTS --n_runs $N_RUNS @@ -262,6 +290,7 @@ jobs: - name: Run Triton FA fwd kernel benchmark - with tensor descriptors if: ${{ steps.install.outcome == 'success' && !cancelled() && (inputs.benchmarks == '' || contains(fromJson(inputs.benchmarks || '[]'), 'flash_attention_tensor_desc_benchmark.py')) && !contains(fromJson(inputs.skip_benchmarks || '[]'), 'flash_attention_tensor_desc_benchmark.py') }} run: | + export LD_LIBRARY_PATH=$PTI_LIBS_DIR:$LD_LIBRARY_PATH cd benchmarks/triton_kernels_benchmark python flash_attention_tensor_desc_benchmark.py --reports $REPORTS --n_runs $N_RUNS mv $REPORTS/attn-performance.csv $REPORTS/attn-tensor-desc-performance.csv @@ -273,6 +302,7 @@ jobs: - name: Run Triton FlexAttention Causal Mask fwd kernel benchmark if: ${{ steps.install.outcome == 'success' && !cancelled() && (inputs.benchmarks == '' || contains(fromJson(inputs.benchmarks || '[]'), 'flex_attention_benchmark_causal_mask.py')) && !contains(fromJson(inputs.skip_benchmarks || '[]'), 'flex_attention_benchmark_causal_mask.py') }} run: | + export LD_LIBRARY_PATH=$PTI_LIBS_DIR:$LD_LIBRARY_PATH cd benchmarks/triton_kernels_benchmark python flex_attention_benchmark_causal_mask.py --reports $REPORTS --n_runs $N_RUNS @@ -283,6 +313,7 @@ jobs: - name: Run Triton FlexAttention (batch_size=4) Causal Mask fwd kernel benchmark if: ${{ steps.install.outcome == 'success' && !cancelled() && (inputs.benchmarks == '' || contains(fromJson(inputs.benchmarks || '[]'), 'flex_attention_benchmark_batch4-causal_mask.py')) && !contains(fromJson(inputs.skip_benchmarks || '[]'), 'flex_attention_benchmark_batch4-causal_mask.py') }} run: | + export LD_LIBRARY_PATH=$PTI_LIBS_DIR:$LD_LIBRARY_PATH cd benchmarks/triton_kernels_benchmark BATCH_SIZE=4 python flex_attention_benchmark_causal_mask.py --reports $REPORTS --n_runs $N_RUNS @@ -293,6 +324,7 @@ jobs: - name: Run Triton FlexAttention (batch_size=16) Causal Mask fwd kernel benchmark if: ${{ steps.install.outcome == 'success' && !cancelled() && (inputs.benchmarks == '' || contains(fromJson(inputs.benchmarks || '[]'), 'flex_attention_benchmark_batch16-causal_mask.py')) && !contains(fromJson(inputs.skip_benchmarks || '[]'), 'flex_attention_benchmark_batch16-causal_mask.py') }} run: | + export LD_LIBRARY_PATH=$PTI_LIBS_DIR:$LD_LIBRARY_PATH cd benchmarks/triton_kernels_benchmark BATCH_SIZE=16 python flex_attention_benchmark_causal_mask.py --reports $REPORTS --n_runs $N_RUNS @@ -303,6 +335,7 @@ jobs: - name: Run Triton FlexAttention Custom Masks fwd kernel benchmark if: ${{ steps.install.outcome == 'success' && !cancelled() && (inputs.benchmarks == '' || contains(fromJson(inputs.benchmarks || '[]'), 'flex_attention_benchmark_custom_masks.py')) && !contains(fromJson(inputs.skip_benchmarks || '[]'), 'flex_attention_benchmark_custom_masks.py') }} run: | + export LD_LIBRARY_PATH=$PTI_LIBS_DIR:$LD_LIBRARY_PATH cd benchmarks/triton_kernels_benchmark python flex_attention_benchmark_custom_masks.py --reports $REPORTS --n_runs $N_RUNS @@ -316,6 +349,7 @@ jobs: - name: Run Prefix Sums kernel benchmark if: ${{ steps.install.outcome == 'success' && !cancelled() && (inputs.benchmarks == '' || contains(fromJson(inputs.benchmarks || '[]'), 'prefix_sums.py')) && !contains(fromJson(inputs.skip_benchmarks || '[]'), 'prefix_sums.py') }} run: | + export LD_LIBRARY_PATH=$PTI_LIBS_DIR:$LD_LIBRARY_PATH cd benchmarks/triton_kernels_benchmark python prefix_sums.py --reports $REPORTS --n_runs $N_RUNS source ../../scripts/capture-hw-details.sh @@ -324,6 +358,7 @@ jobs: - name: Run micro benchmark if: ${{ steps.install.outcome == 'success' && !cancelled() && (inputs.benchmarks == '' || contains(fromJson(inputs.benchmarks || '[]'), 'micro_benchmarks.py')) && !contains(fromJson(inputs.skip_benchmarks || '[]'), 'micro_benchmarks') }} run: | + export LD_LIBRARY_PATH=$PTI_LIBS_DIR:$LD_LIBRARY_PATH cd benchmarks/micro_benchmarks python run_benchmarks.py --reports $REPORTS diff --git a/third_party/intel/backend/proton/include/pti/pti.h b/third_party/intel/backend/proton/include/pti/pti.h index 3bd6a3d363..512a839154 100644 --- a/third_party/intel/backend/proton/include/pti/pti.h +++ b/third_party/intel/backend/proton/include/pti/pti.h @@ -31,7 +31,9 @@ typedef enum { //!< PTI_VIEW_EXTERNAL_CORRELATION PTI_ERROR_BAD_TIMESTAMP = 6, //!< error in timestamp conversion, might be related with the user //!< provided TimestampCallback - PTI_ERROR_BAD_API_ID = 7, //!< invalid api_id when enable/disable runtime/driver specific api_id + PTI_ERROR_BAD_API_ID = 7, //!< invalid api_id when enable/disable runtime/driver specific api_id + PTI_ERROR_NO_GPU_VIEWS_ENABLED = 8, //!< at least one GPU view must be enabled for kernel tracing + PTI_ERROR_DRIVER = 50, //!< unknown driver error PTI_ERROR_TRACING_NOT_INITIALIZED = 51, //!< installed driver requires tracing enabling with //!< setting environment variable ZE_ENABLE_TRACING_LAYER @@ -57,6 +59,25 @@ typedef enum { */ PTI_EXPORT const char* ptiResultTypeToString(pti_result result_value); + +/** + * @brief Abstraction for backend-specific objects. + * + * Level Zero is currently the only supported backend. However, these types will attempt to serve other backends. + * In case the other backend supported - the same types will serve it. + */ + +typedef void* pti_device_handle_t; //!< Device handle + +typedef void* pti_backend_ctx_t; //!< Backend context handle + +typedef void* pti_backend_queue_t; //!< Backend queue handle + +typedef void* pti_backend_evt_t; //!< Backend event handle + +typedef void* pti_backend_command_list_t; //!< Backend command list handle + + #if defined(__cplusplus) } #endif diff --git a/third_party/intel/backend/proton/include/pti/pti_callback.h b/third_party/intel/backend/proton/include/pti/pti_callback.h new file mode 100644 index 0000000000..0659cec7fa --- /dev/null +++ b/third_party/intel/backend/proton/include/pti/pti_callback.h @@ -0,0 +1,234 @@ +//============================================================== +// Copyright (C) Intel Corporation +// +// SPDX-License-Identifier: MIT +// ============================================================= + +#ifndef PTI_CALLBACK_H_ +#define PTI_CALLBACK_H_ + +#include + +#include "pti/pti.h" +#include "pti/pti_view.h" + +/** + * This file contains APIs that are so far experimental in PTI. + * APIs and data structures in this file are work-in-progress and subject to change! + * All content in this file concerns the Callback API. + * + * The Callback API is useful for many purposes, + * including the implementation of `MetricsScope` functionality that needs to subscribe to + * domains such as kernel append to a command list, and potentially other domains. + * The `MetricsScope` API is under development and is the first (internal) user of the Callback API. + */ + + +/* clang-format off */ +#if defined(__cplusplus) +extern "C" { +#endif + +typedef struct _pti_callback_subscriber* pti_callback_subscriber_handle; + +typedef enum _pti_callback_domain { + PTI_CB_DOMAIN_INVALID = 0, + PTI_CB_DOMAIN_DRIVER_CONTEXT_CREATED = 1, //!< Not implemented yet + //!< attempt to enable it will return PTI_ERROR_NOT_IMPLEMENTED + + PTI_CB_DOMAIN_DRIVER_MODULE_LOADED = 2, //!< Not implemented yet + //!< attempt to enable it will return PTI_ERROR_NOT_IMPLEMENTED + + PTI_CB_DOMAIN_DRIVER_MODULE_UNLOADED = 3, //!< Not implemented yet + //!< attempt to enable it will return PTI_ERROR_NOT_IMPLEMENTED + + PTI_CB_DOMAIN_DRIVER_GPU_OPERATION_APPENDED = 4, //!< Synchronous callback + //!< This also serves as PTI_CB_DOMAIN_DRIVER_GPU_OPERATION_DISPATCHED + //!< when appended to Immediate Command List, + //!< which means no separate callback PTI_CB_DOMAIN_DRIVER_GPU_OPERATION_DISPATCHED + + PTI_CB_DOMAIN_DRIVER_GPU_OPERATION_DISPATCHED = 5, //!< Not implemented yet + //!< attempt to enable it will return PTI_ERROR_NOT_IMPLEMENTED + + PTI_CB_DOMAIN_DRIVER_GPU_OPERATION_COMPLETED = 6, //!< Asynchronous callback, always has only EXIT phase of some API, + //!< where completed operations are collected and reported + + PTI_CB_DOMAIN_DRIVER_HOST_SYNCHRONIZATION = 7, //!< Not implemented yet + //!< attempt to enable it will return PTI_ERROR_NOT_IMPLEMENTED + + PTI_CB_DOMAIN_DRIVER_API = 1023, //!< Not implemented yet, + //!< attempt to enable it will return PTI_ERROR_NOT_IMPLEMENTED + //!< Callback created for all Driver APIs + // below domains to inform user about PTI internal events + PTI_CB_DOMAIN_INTERNAL_THREADS = 1024, //!< Not implemented yet + PTI_CB_DOMAIN_INTERNAL_EVENT = 1025, //!< Not implemented yet + + PTI_CB_DOMAIN_MAX = 0x7fffffff +} pti_callback_domain; + +typedef enum _pti_callback_phase { + PTI_CB_PHASE_INVALID = 0, + PTI_CB_PHASE_API_ENTER = 1, + PTI_CB_PHASE_API_EXIT = 2, + PTI_CB_PHASE_INTERNAL_THREAD_START = 3, + PTI_CB_PHASE_INTERNAL_THREAD_END = 4, + PTI_CB_PHASE_INTERNAL_EVENT = 5, + + PTI_CB_PHASE_MAX = 0x7fffffff +} pti_callback_phase; + +typedef enum _pti_backend_command_list_type { + PTI_BACKEND_COMMAND_LIST_TYPE_UNKNOWN = (1<<0), + PTI_BACKEND_COMMAND_LIST_TYPE_IMMEDIATE = (1<<1), + PTI_BACKEND_COMMAND_LIST_TYPE_MUTABLE = (1<<2), + + PTI_BACKEND_COMMAND_LIST_TYPE_MAX = 0x7fffffff +} pti_backend_command_list_type; + +/** + * A user can subscribe to notifications about non-standard situations from PTI + * when it collects or processes the data + */ +typedef enum _pti_internal_event_type { + PTI_INTERNAL_EVENT_TYPE_INFO = 0, + PTI_INTERNAL_EVENT_TYPE_WARNING = 1, // one or a few records data inconsistencies, or other + // collection is safe to continue + PTI_INTERNAL_EVENT_TYPE_CRITICAL = 2, // critical error after which further collected data are invalid + + PTI_INTERNAL_EVENT_TYPE_MAX = 0x7fffffff +} pti_internal_event_type; + +typedef enum _pti_gpu_operation_kind { + PTI_GPU_OPERATION_KIND_INVALID = 0, + PTI_GPU_OPERATION_KIND_KERNEL = 1, + PTI_GPU_OPERATION_KIND_MEMORY = 2, + PTI_GPU_OPERATION_KIND_OTHER = 3, + + PTI_GPU_OPERATION_KIND_MAX = 0x7fffffff +} pti_gpu_operation_kind; + +typedef struct _pti_gpu_op_details { + pti_gpu_operation_kind _operation_kind; // #include namespace proton { @@ -15,6 +16,22 @@ template pti_result viewDisable(pti_view_kind kind); template pti_result viewFlushAll(); +template +pti_result subscribe(pti_callback_subscriber_handle *subscriber, + pti_callback_function callback, void *user_data); + +template +pti_result unsubscribe(pti_callback_subscriber_handle subscriber); + +template +pti_result enableDomain(pti_callback_subscriber_handle subscriber, + pti_callback_domain domain, uint32_t enter_cb, + uint32_t exit_cb); + +template +pti_result disableDomain(pti_callback_subscriber_handle subscriber, + pti_callback_domain domain); + template pti_result viewGetNextRecord(uint8_t *buffer, size_t valid_bytes, pti_view_record_base **record); diff --git a/third_party/proton/csrc/lib/Driver/GPU/XpuptiApi.cpp b/third_party/proton/csrc/lib/Driver/GPU/XpuptiApi.cpp index 726199781d..c618661372 100644 --- a/third_party/proton/csrc/lib/Driver/GPU/XpuptiApi.cpp +++ b/third_party/proton/csrc/lib/Driver/GPU/XpuptiApi.cpp @@ -10,6 +10,7 @@ struct ExternLibXpupti : public ExternLibBase { using RetType = pti_result; static constexpr const char *name = "libpti_view.so"; static constexpr const char *defaultDir = ""; + static constexpr const char *pathEnv = "TRITON_XPUPTI_LIB_PATH"; static constexpr RetType success = PTI_SUCCESS; static void *lib; }; @@ -24,6 +25,19 @@ DEFINE_DISPATCH(ExternLibXpupti, viewDisable, ptiViewDisable, pti_view_kind) DEFINE_DISPATCH(ExternLibXpupti, viewFlushAll, ptiFlushAllViews) +DEFINE_DISPATCH(ExternLibXpupti, subscribe, ptiCallbackSubscribe, + pti_callback_subscriber_handle *, pti_callback_function, void *) + +DEFINE_DISPATCH(ExternLibXpupti, unsubscribe, ptiCallbackUnsubscribe, + pti_callback_subscriber_handle); + +DEFINE_DISPATCH(ExternLibXpupti, enableDomain, ptiCallbackEnableDomain, + pti_callback_subscriber_handle, pti_callback_domain, uint32_t, + uint32_t); + +DEFINE_DISPATCH(ExternLibXpupti, disableDomain, ptiCallbackDisableDomain, + pti_callback_subscriber_handle, pti_callback_domain); + DEFINE_DISPATCH(ExternLibXpupti, viewGetNextRecord, ptiViewGetNextRecord, uint8_t *, size_t, pti_view_record_base **) diff --git a/third_party/proton/csrc/lib/Profiler/Xpupti/XpuptiProfiler.cpp b/third_party/proton/csrc/lib/Profiler/Xpupti/XpuptiProfiler.cpp index 48ba83f564..e8651323dd 100644 --- a/third_party/proton/csrc/lib/Profiler/Xpupti/XpuptiProfiler.cpp +++ b/third_party/proton/csrc/lib/Profiler/Xpupti/XpuptiProfiler.cpp @@ -264,14 +264,17 @@ struct XpuptiProfiler::XpuptiProfilerPimpl static void allocBuffer(uint8_t **buffer, size_t *bufferSize); static void completeBuffer(uint8_t *buffer, size_t size, size_t validSize); - /* - static void callbackFn(void *userData, CUpti_CallbackDomain domain, - CUpti_CallbackId cbId, const void *cbData); - */ + static void callbackFn(pti_callback_domain domain, + pti_api_group_id driver_api_group_id, + uint32_t driver_api_id, + pti_backend_ctx_t backend_context, void *cb_data, + void *global_user_data, void **instance_user_data); static constexpr size_t AlignSize = 8; static constexpr size_t BufferSize = 64 * 1024 * 1024; + pti_callback_subscriber_handle subscriber; + /* static constexpr size_t AttributeSize = sizeof(size_t); @@ -327,6 +330,53 @@ void XpuptiProfiler::XpuptiProfilerPimpl::completeBuffer(uint8_t *buffer, profiler.correlation.complete(maxCorrelationId); } +void XpuptiProfiler::XpuptiProfilerPimpl::callbackFn( + pti_callback_domain domain, pti_api_group_id driver_api_group_id, + uint32_t driver_api_id, pti_backend_ctx_t backend_context, void *cb_data, + void *global_user_data, void **instance_user_data) { + std::cout << "callback\n" << std::flush; + pti_callback_gpu_op_data *callback_data = + static_cast(cb_data); + if (callback_data == nullptr) { + std::cerr << "CallbackGPUOperationAppend: callback_data is null" + << std::endl; + return; + } + if (callback_data->_phase == PTI_CB_PHASE_API_ENTER) { + threadState.enterOp(); + threadState.profiler.correlation.correlate(callback_data->_correlation_id, + 1); + } else if (callback_data->_phase == PTI_CB_PHASE_API_EXIT) { + threadState.exitOp(); + threadState.profiler.correlation.submit(callback_data->_correlation_id); + } else { + throw std::runtime_error("[PROTON] callbackFn failed"); + } +} + +void CallbackCommon(pti_callback_domain domain, + pti_api_group_id driver_group_id, uint32_t driver_api_id, + [[maybe_unused]] pti_backend_ctx_t backend_context, + [[maybe_unused]] void *cb_data, + [[maybe_unused]] void *user_data) { + + switch (domain) { + case PTI_CB_DOMAIN_DRIVER_GPU_OPERATION_APPENDED: + std::cout << "PTI_CB_DOMAIN_DRIVER_GPU_OPERATION_APPENDED\n" << std::flush; + break; + case PTI_CB_DOMAIN_DRIVER_GPU_OPERATION_COMPLETED: + std::cout << "PTI_CB_DOMAIN_DRIVER_GPU_OPERATION_COMPLETED\n" << std::flush; + break; + default: { + std::cout << "In " << __func__ << ", domain: " << domain + << ", driver_group_id: " << driver_group_id + << ", driver_api_id: " << driver_api_id << std::endl; + break; + } + } + std::cout << std::endl; +} + zel_tracer_handle_t tracer = nullptr; typedef void (*EnumDeviceUUIDsFunc)(std::vector>); @@ -380,7 +430,6 @@ int callWaitOnSyclQueue(const std::string &utils_cache_path, void *syclQueue) { } void XpuptiProfiler::XpuptiProfilerPimpl::doStart() { - // xpupti::subscribe(&subscriber, callbackFn, nullptr); // should be call to shared lib XpuptiProfiler &profiler = threadState.profiler; if (profiler.utils_cache_path != "") { @@ -389,13 +438,13 @@ void XpuptiProfiler::XpuptiProfilerPimpl::doStart() { // auto res = ptiViewPushExternalCorrelationId( // pti_view_external_kind::PTI_VIEW_EXTERNAL_KIND_CUSTOM_1, 42); // std::cout << "res: " << res << "\n" << std::flush; - + /* ze_result_t status = ZE_RESULT_SUCCESS; // status = zeInit(ZE_INIT_FLAG_GPU_ONLY); // assert(status == ZE_RESULT_SUCCESS); zel_tracer_desc_t tracer_desc = {ZEL_STRUCTURE_TYPE_TRACER_DESC, nullptr, - nullptr /* global user data */}; + nullptr}; status = zelTracerCreate(&tracer_desc, &tracer); std::cout << "zelTracerCreate: " << status << "\n" << std::flush; @@ -417,9 +466,13 @@ void XpuptiProfiler::XpuptiProfilerPimpl::doStart() { status = zelTracerSetEnabled(tracer, true); assert(status == ZE_RESULT_SUCCESS); + */ xpupti::viewSetCallbacks(allocBuffer, completeBuffer); xpupti::viewEnable(PTI_VIEW_DEVICE_GPU_KERNEL); + xpupti::viewEnable(PTI_VIEW_DEVICE_GPU_MEM_FILL); + xpupti::viewEnable(PTI_VIEW_DEVICE_GPU_MEM_COPY); + xpupti::subscribe(&subscriber, callbackFn, &subscriber); // xpupti::viewEnable(PTI_VIEW_DEVICE_GPU_MEM_COPY); // xpupti::viewEnable(PTI_VIEW_DEVICE_GPU_MEM_FILL); // xpupti::viewEnable(PTI_VIEW_SYCL_RUNTIME_CALLS); @@ -428,6 +481,8 @@ void XpuptiProfiler::XpuptiProfilerPimpl::doStart() { // xpupti::viewEnable(PTI_VIEW_LEVEL_ZERO_CALLS); // setGraphCallbacks(subscriber, /*enable=*/true); // setRuntimeCallbacks(subscriber, /*enable=*/true); + xpupti::enableDomain(subscriber, + PTI_CB_DOMAIN_DRIVER_GPU_OPERATION_APPENDED, 1, 1); // setDriverCallbacks(subscriber, /*enable=*/true); } @@ -444,13 +499,17 @@ void XpuptiProfiler::XpuptiProfilerPimpl::doFlush() { } void XpuptiProfiler::XpuptiProfilerPimpl::doStop() { + /* ze_result_t status = ZE_RESULT_SUCCESS; status = zelTracerSetEnabled(tracer, false); assert(status == ZE_RESULT_SUCCESS); status = zelTracerDestroy(tracer); assert(status == ZE_RESULT_SUCCESS); + */ xpupti::viewDisable(PTI_VIEW_DEVICE_GPU_KERNEL); + xpupti::viewDisable(PTI_VIEW_DEVICE_GPU_MEM_FILL); + xpupti::viewDisable(PTI_VIEW_DEVICE_GPU_MEM_COPY); // xpupti::viewDisable(PTI_VIEW_DEVICE_GPU_MEM_COPY); // xpupti::viewDisable(PTI_VIEW_DEVICE_GPU_MEM_FILL); // xpupti::viewDisable(PTI_VIEW_SYCL_RUNTIME_CALLS); @@ -460,7 +519,9 @@ void XpuptiProfiler::XpuptiProfilerPimpl::doStop() { // setGraphCallbacks(subscriber, /*enable=*/false); // setRuntimeCallbacks(subscriber, /*enable=*/false); // setDriverCallbacks(subscriber, /*enable=*/false); - // cupti::unsubscribe(subscriber); + xpupti::disableDomain(subscriber, + PTI_CB_DOMAIN_DRIVER_GPU_OPERATION_APPENDED); + xpupti::unsubscribe(subscriber); // cupti::finalize(); } diff --git a/third_party/proton/test/test_profile.py b/third_party/proton/test/test_profile.py index f85a4f77d2..15a97dc578 100644 --- a/third_party/proton/test/test_profile.py +++ b/third_party/proton/test/test_profile.py @@ -258,8 +258,6 @@ def foo(x, size: tl.constexpr, y): def test_hook_with_third_party(tmp_path: pathlib.Path): - if is_xpu(): - pytest.skip("FIXME: enable") third_party_hook_invoked = False def third_party_hook(metadata) -> None: @@ -280,7 +278,7 @@ def foo(x, size: tl.constexpr, y): offs = tl.arange(0, size) tl.store(y + offs, tl.load(x + offs)) - x = torch.tensor([2], device="cuda", dtype=torch.float32) + x = torch.tensor([2], device="xpu", dtype=torch.float32) y = torch.zeros_like(x) temp_file = tmp_path / "test_hook_with_third_party.hatchet" proton.start(str(temp_file.with_suffix("")), hook="triton") @@ -295,8 +293,6 @@ def foo(x, size: tl.constexpr, y): def test_hook_multiple_threads(tmp_path: pathlib.Path): - if is_xpu(): - pytest.skip("FIXME: enable") def metadata_fn_foo(grid: tuple, metadata: NamedTuple, args: dict): return {"name": "foo_test"} @@ -314,9 +310,9 @@ def bar(x, size: tl.constexpr, y): offs = tl.arange(0, size) tl.store(y + offs, tl.load(x + offs)) - x_foo = torch.tensor([2], device="cuda", dtype=torch.float32) + x_foo = torch.tensor([2], device="xpu", dtype=torch.float32) y_foo = torch.zeros_like(x_foo) - x_bar = torch.tensor([2], device="cuda", dtype=torch.float32) + x_bar = torch.tensor([2], device="xpu", dtype=torch.float32) y_bar = torch.zeros_like(x_bar) temp_file = tmp_path / "test_hook.hatchet" @@ -410,10 +406,6 @@ def test_deactivate(tmp_path: pathlib.Path): def test_multiple_sessions(tmp_path: pathlib.Path): - if is_xpu(): - # FIXME: the same correlation id, that's why it's filtered, - # should `_kernel_id` be used instead - pytest.xfail('assert int(data[0]["children"][0]["metrics"]["count"]) == 2') temp_file0 = tmp_path / "test_multiple_sessions0.hatchet" temp_file1 = tmp_path / "test_multiple_sessions1.hatchet" session_id0 = proton.start(str(temp_file0.with_suffix(""))) @@ -439,8 +431,6 @@ def test_multiple_sessions(tmp_path: pathlib.Path): def test_trace(tmp_path: pathlib.Path): - if is_xpu(): - pytest.skip("FIXME: enable") temp_file = tmp_path / "test_trace.chrome_trace" proton.start(str(temp_file.with_suffix("")), data="trace") @@ -450,7 +440,7 @@ def foo(x, y, size: tl.constexpr): tl.store(y + offs, tl.load(x + offs)) with proton.scope("init"): - x = torch.ones((1024, ), device="cuda", dtype=torch.float32) + x = torch.ones((1024, ), device="xpu", dtype=torch.float32) y = torch.zeros_like(x) with proton.scope("test"): @@ -467,8 +457,6 @@ def foo(x, y, size: tl.constexpr): def test_scope_multiple_threads(tmp_path: pathlib.Path): - if is_xpu(): - pytest.skip("FIXME: enable") temp_file = tmp_path / "test_scope_threads.hatchet" proton.start(str(temp_file.with_suffix(""))) @@ -479,7 +467,7 @@ def worker(prefix: str): for i in range(N): name = f"{prefix}_{i}" proton.enter_scope(name) - torch.ones((1, ), device="cuda") + torch.ones((1, ), device="xpu") proton.exit_scope() threads = [threading.Thread(target=worker, args=(tname, )) for tname in thread_names]